// File: app/src-tauri/src/proxy/tests/security_ip_tests.rs
// (uploaded by AZILS, commit a21c316)
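//! IP Security Module Tests
//! IP 安全监控功能的综合测试套件 (comprehensive test suite for IP security monitoring)
//!
//! Test goals:
//! 1. Verify that the IP blacklist/whitelist behaves correctly
//! 2. Verify the CIDR matching logic
//! 3. Verify expiration-time handling
//! 4. Verify that the checks do not hurt main-path performance
//! 5. Verify the atomicity and consistency of database operations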
#[cfg(test)]
mod security_db_tests {
    use crate::modules::security_db::{
        add_to_blacklist, add_to_whitelist, cleanup_old_ip_logs, clear_ip_access_logs,
        get_blacklist, get_blacklist_entry_for_ip, get_ip_access_logs, get_ip_stats,
        get_whitelist, init_db, is_ip_in_blacklist, is_ip_in_whitelist, remove_from_blacklist,
        remove_from_whitelist, save_ip_access_log, IpAccessLog,
    };
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Lock that serializes every test in this file that touches the security database.
    ///
    /// `cargo test` runs tests on parallel threads by default. However, every
    /// `security_db` function used here operates on shared state (none of them takes a
    /// connection or handle), and `cleanup_test_data` wipes *all* blacklist, whitelist
    /// and access-log rows. Without this lock, one test's cleanup races with another
    /// test's exact-count assertions, and so does a rule such as `0.0.0.0/0`.
    static DB_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `DB_TEST_LOCK`, recovering from poisoning.
    ///
    /// A poisoned lock only means an earlier test panicked while holding it. Every test
    /// resets the data it relies on, so continuing is safe and avoids one failure
    /// cascading into every later test.
    pub(super) fn lock_db() -> MutexGuard<'static, ()> {
        DB_TEST_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Returns the current Unix timestamp in whole seconds.
    fn now_timestamp() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
    }

    /// Best-effort removal of every blacklist entry (errors are ignored).
    pub(super) fn clear_blacklist() {
        if let Ok(entries) = get_blacklist() {
            for entry in entries {
                let _ = remove_from_blacklist(&entry.id);
            }
        }
    }

    /// Best-effort removal of every whitelist entry (errors are ignored).
    fn clear_whitelist() {
        if let Ok(entries) = get_whitelist() {
            for entry in entries {
                let _ = remove_from_whitelist(&entry.id);
            }
        }
    }

    /// Best-effort reset of the test environment: blacklist, whitelist and access logs.
    fn cleanup_test_data() {
        clear_blacklist();
        clear_whitelist();
        let _ = clear_ip_access_logs();
    }

    /// Acquires the shared DB lock, calls `init_db`, and wipes all test data.
    ///
    /// Bind the result to a named variable (`let _guard = setup();`). Binding it to `_`
    /// drops the guard immediately, which releases the lock.
    fn setup() -> MutexGuard<'static, ()> {
        let guard = lock_db();
        // Initialization success is asserted by the dedicated init tests; here it is
        // best-effort so that those tests report the real failure.
        let _ = init_db();
        cleanup_test_data();
        guard
    }

    // ============================================================================
    // Category 1: database initialization
    // ============================================================================

    #[test]
    fn test_db_initialization() {
        let _guard = lock_db();
        // Initializing the database must not panic and must report success.
        let result = init_db();
        assert!(result.is_ok(), "Database initialization should succeed: {:?}", result.err());
    }

    #[test]
    fn test_db_multiple_initializations() {
        let _guard = lock_db();
        // Repeated initialization must keep succeeding (idempotence).
        for _ in 0..3 {
            let result = init_db();
            assert!(result.is_ok(), "Multiple DB initializations should be idempotent");
        }
    }

    // ============================================================================
    // Category 2: basic IP blacklist operations
    // ============================================================================

    #[test]
    fn test_blacklist_add_and_check() {
        let _guard = setup();

        let result = add_to_blacklist("192.168.1.100", Some("Test block"), None, "test");
        assert!(result.is_ok(), "Should add IP to blacklist: {:?}", result.err());

        // The added IP is reported as blocked...
        let is_blocked = is_ip_in_blacklist("192.168.1.100");
        assert!(is_blocked.is_ok());
        assert!(is_blocked.unwrap(), "IP should be in blacklist");

        // ...and a neighbouring IP is not.
        let is_other_blocked = is_ip_in_blacklist("192.168.1.101");
        assert!(is_other_blocked.is_ok());
        assert!(!is_other_blocked.unwrap(), "Other IP should not be in blacklist");

        cleanup_test_data();
    }

    #[test]
    fn test_blacklist_remove() {
        let _guard = setup();

        let entry = add_to_blacklist("10.0.0.5", Some("Temp block"), None, "test")
            .expect("setup: adding 10.0.0.5 should succeed");
        assert!(is_ip_in_blacklist("10.0.0.5").unwrap());

        let remove_result = remove_from_blacklist(&entry.id);
        assert!(remove_result.is_ok());

        // Once removed, the IP is no longer blocked.
        assert!(!is_ip_in_blacklist("10.0.0.5").unwrap());

        cleanup_test_data();
    }

    #[test]
    fn test_blacklist_get_entry_details() {
        let _guard = setup();

        add_to_blacklist(
            "172.16.0.50",
            Some("Abuse detected"),
            Some(now_timestamp() + 3600), // expires in one hour
            "admin",
        )
        .expect("setup: adding 172.16.0.50 should succeed");

        let entry = get_blacklist_entry_for_ip("172.16.0.50")
            .expect("entry lookup should succeed")
            .expect("entry for 172.16.0.50 should exist");
        assert_eq!(entry.ip_pattern, "172.16.0.50");
        assert_eq!(entry.reason.as_deref(), Some("Abuse detected"));
        assert_eq!(entry.created_by, "admin");
        assert!(entry.expires_at.is_some());

        cleanup_test_data();
    }

    // ============================================================================
    // Category 3: CIDR matching
    // ============================================================================

    #[test]
    fn test_cidr_matching_basic() {
        let _guard = setup();

        add_to_blacklist("192.168.1.0/24", Some("Block subnet"), None, "test")
            .expect("setup: adding 192.168.1.0/24 should succeed");

        // Every address inside the subnet is blocked.
        assert!(is_ip_in_blacklist("192.168.1.1").unwrap(), "192.168.1.1 should match /24");
        assert!(is_ip_in_blacklist("192.168.1.100").unwrap(), "192.168.1.100 should match /24");
        assert!(is_ip_in_blacklist("192.168.1.254").unwrap(), "192.168.1.254 should match /24");

        // Addresses outside the subnet are not.
        assert!(!is_ip_in_blacklist("192.168.2.1").unwrap(), "192.168.2.1 should not match");
        assert!(!is_ip_in_blacklist("10.0.0.1").unwrap(), "10.0.0.1 should not match");

        cleanup_test_data();
    }

    #[test]
    fn test_cidr_matching_various_masks() {
        let _guard = setup();

        // /16 mask.
        add_to_blacklist("10.10.0.0/16", Some("Block /16"), None, "test")
            .expect("setup: adding 10.10.0.0/16 should succeed");
        assert!(is_ip_in_blacklist("10.10.0.1").unwrap(), "Should match /16");
        assert!(is_ip_in_blacklist("10.10.255.255").unwrap(), "Should match /16");
        assert!(!is_ip_in_blacklist("10.11.0.1").unwrap(), "Should not match /16");
        cleanup_test_data();

        // /32 mask (a single host).
        add_to_blacklist("8.8.8.8/32", Some("Block single"), None, "test")
            .expect("setup: adding 8.8.8.8/32 should succeed");
        assert!(is_ip_in_blacklist("8.8.8.8").unwrap(), "Should match /32");
        assert!(!is_ip_in_blacklist("8.8.8.9").unwrap(), "Should not match /32");

        cleanup_test_data();
    }

    #[test]
    fn test_cidr_edge_cases() {
        let _guard = setup();

        // /0 matches every address. While this rule exists, every other DB test would
        // see its IPs as blocked, which is why all DB tests hold `DB_TEST_LOCK`.
        add_to_blacklist("0.0.0.0/0", Some("Block all"), None, "test")
            .expect("setup: adding 0.0.0.0/0 should succeed");
        assert!(is_ip_in_blacklist("1.2.3.4").unwrap(), "Everything should match /0");
        assert!(is_ip_in_blacklist("255.255.255.255").unwrap(), "Everything should match /0");
        cleanup_test_data();

        // /8 mask.
        add_to_blacklist("10.0.0.0/8", Some("Block /8"), None, "test")
            .expect("setup: adding 10.0.0.0/8 should succeed");
        assert!(is_ip_in_blacklist("10.255.255.255").unwrap(), "Should match /8");
        assert!(!is_ip_in_blacklist("11.0.0.0").unwrap(), "Should not match /8");

        cleanup_test_data();
    }

    // ============================================================================
    // Category 4: expiration handling
    // ============================================================================

    #[test]
    fn test_blacklist_expiration() {
        let _guard = setup();

        // Deliberately best-effort: the implementation may reject an already-expired
        // entry. Either way, the IP must not be reported as blocked.
        let _ = add_to_blacklist(
            "expired.test.ip",
            Some("Already expired"),
            Some(now_timestamp() - 60), // expired one minute ago
            "test",
        );

        // Per the original author's note, get_blacklist_entry_for_ip purges expired
        // entries before matching (not verified from this file).
        let is_blocked = is_ip_in_blacklist("expired.test.ip");
        assert!(!is_blocked.unwrap(), "Expired entry should be cleaned up");

        cleanup_test_data();
    }

    #[test]
    fn test_blacklist_not_yet_expired() {
        let _guard = setup();

        add_to_blacklist(
            "not.expired.ip",
            Some("Will expire later"),
            Some(now_timestamp() + 3600), // expires in one hour
            "test",
        )
        .expect("setup: adding not.expired.ip should succeed");

        // An entry that has not expired yet is still enforced.
        assert!(is_ip_in_blacklist("not.expired.ip").unwrap());

        cleanup_test_data();
    }

    #[test]
    fn test_permanent_blacklist() {
        let _guard = setup();

        add_to_blacklist(
            "permanent.block.ip",
            Some("Permanent ban"),
            None, // no expiry
            "test",
        )
        .expect("setup: adding permanent.block.ip should succeed");

        // A permanent ban is always enforced.
        assert!(is_ip_in_blacklist("permanent.block.ip").unwrap());

        cleanup_test_data();
    }

    // ============================================================================
    // Category 5: IP whitelist
    // ============================================================================

    #[test]
    fn test_whitelist_add_and_check() {
        let _guard = setup();

        let result = add_to_whitelist("10.0.0.1", Some("Trusted server"));
        assert!(result.is_ok());

        assert!(is_ip_in_whitelist("10.0.0.1").unwrap());
        assert!(!is_ip_in_whitelist("10.0.0.2").unwrap());

        cleanup_test_data();
    }

    #[test]
    fn test_whitelist_cidr() {
        let _guard = setup();

        assert!(
            add_to_whitelist("192.168.0.0/16", Some("Internal network")).is_ok(),
            "setup: adding 192.168.0.0/16 to whitelist should succeed"
        );

        // Addresses inside the range are whitelisted...
        assert!(is_ip_in_whitelist("192.168.1.1").unwrap());
        assert!(is_ip_in_whitelist("192.168.255.255").unwrap());

        // ...addresses outside are not.
        assert!(!is_ip_in_whitelist("10.0.0.1").unwrap());

        cleanup_test_data();
    }

    // ============================================================================
    // Category 6: IP access log
    // ============================================================================

    #[test]
    fn test_access_log_save_and_retrieve() {
        let _guard = setup();

        let log = IpAccessLog {
            id: uuid::Uuid::new_v4().to_string(),
            client_ip: "test.log.ip".to_string(),
            timestamp: now_timestamp(),
            method: Some("POST".to_string()),
            path: Some("/v1/messages".to_string()),
            user_agent: Some("TestClient/1.0".to_string()),
            status: Some(200),
            duration: Some(150),
            api_key_hash: Some("hash123".to_string()),
            blocked: false,
            block_reason: None,
            username: None,
        };
        let save_result = save_ip_access_log(&log);
        assert!(save_result.is_ok(), "Should save access log: {:?}", save_result.err());

        // Retrieve logs filtered by client IP.
        let logs = get_ip_access_logs(10, 0, Some("test.log.ip"), false);
        assert!(logs.is_ok());
        let logs = logs.unwrap();
        assert!(!logs.is_empty(), "Should retrieve saved log");
        assert_eq!(logs[0].client_ip, "test.log.ip");

        cleanup_test_data();
    }

    #[test]
    fn test_access_log_blocked_filter() {
        let _guard = setup();

        // A normal (allowed) request.
        let normal_log = IpAccessLog {
            id: uuid::Uuid::new_v4().to_string(),
            client_ip: "normal.access.ip".to_string(),
            timestamp: now_timestamp(),
            method: Some("GET".to_string()),
            path: Some("/healthz".to_string()),
            user_agent: None,
            status: Some(200),
            duration: Some(10),
            api_key_hash: None,
            blocked: false,
            block_reason: None,
            username: None,
        };
        save_ip_access_log(&normal_log).expect("setup: saving normal log should succeed");

        // A blocked request.
        let blocked_log = IpAccessLog {
            id: uuid::Uuid::new_v4().to_string(),
            client_ip: "blocked.access.ip".to_string(),
            timestamp: now_timestamp(),
            method: Some("POST".to_string()),
            path: Some("/v1/messages".to_string()),
            user_agent: None,
            status: Some(403),
            duration: Some(0),
            api_key_hash: None,
            blocked: true,
            block_reason: Some("IP in blacklist".to_string()),
            username: None,
        };
        save_ip_access_log(&blocked_log).expect("setup: saving blocked log should succeed");

        // With the blocked-only filter, exactly the blocked request comes back. This exact
        // count relies on the DB lock keeping other tests' logs out.
        let blocked_only = get_ip_access_logs(10, 0, None, true).unwrap();
        assert_eq!(blocked_only.len(), 1);
        assert_eq!(blocked_only[0].client_ip, "blocked.access.ip");
        assert!(blocked_only[0].blocked);

        cleanup_test_data();
    }

    // ============================================================================
    // Category 7: statistics
    // ============================================================================

    #[test]
    fn test_ip_stats() {
        let _guard = setup();

        // Five requests spread over three distinct IPs; the last one is blocked.
        for i in 0..5 {
            let log = IpAccessLog {
                id: uuid::Uuid::new_v4().to_string(),
                client_ip: format!("stats.test.{}", i % 3),
                timestamp: now_timestamp(),
                method: Some("POST".to_string()),
                path: Some("/v1/messages".to_string()),
                user_agent: None,
                status: Some(200),
                duration: Some(100),
                api_key_hash: None,
                blocked: i == 4,
                block_reason: if i == 4 { Some("Test".to_string()) } else { None },
                username: None,
            };
            save_ip_access_log(&log).expect("setup: saving stats log should succeed");
        }

        // Two blacklist entries and one whitelist entry.
        add_to_blacklist("stats.black.1", None, None, "test")
            .expect("setup: adding stats.black.1 should succeed");
        add_to_blacklist("stats.black.2", None, None, "test")
            .expect("setup: adding stats.black.2 should succeed");
        assert!(
            add_to_whitelist("stats.white.1", None).is_ok(),
            "setup: adding stats.white.1 to whitelist should succeed"
        );

        let stats = get_ip_stats();
        assert!(stats.is_ok());
        let stats = stats.unwrap();
        assert!(stats.total_requests >= 5, "Should have at least 5 requests");
        assert!(stats.unique_ips >= 3, "Should have at least 3 unique IPs");
        assert!(stats.blocked_count >= 1, "Should have at least 1 blocked request");
        // Exact counts rely on the DB lock keeping other tests' entries out.
        assert_eq!(stats.blacklist_count, 2);
        assert_eq!(stats.whitelist_count, 1);

        cleanup_test_data();
    }

    // ============================================================================
    // Category 8: log cleanup
    // ============================================================================

    #[test]
    fn test_cleanup_old_logs() {
        let _guard = setup();

        // An "old" log, timestamped two days ago.
        let old_log = IpAccessLog {
            id: uuid::Uuid::new_v4().to_string(),
            client_ip: "old.log.ip".to_string(),
            timestamp: now_timestamp() - (2 * 24 * 3600),
            method: Some("GET".to_string()),
            path: Some("/old".to_string()),
            user_agent: None,
            status: Some(200),
            duration: Some(10),
            api_key_hash: None,
            blocked: false,
            block_reason: None,
            username: None,
        };
        save_ip_access_log(&old_log).expect("setup: saving old log should succeed");

        // A fresh log.
        let new_log = IpAccessLog {
            id: uuid::Uuid::new_v4().to_string(),
            client_ip: "new.log.ip".to_string(),
            timestamp: now_timestamp(),
            method: Some("GET".to_string()),
            path: Some("/new".to_string()),
            user_agent: None,
            status: Some(200),
            duration: Some(10),
            api_key_hash: None,
            blocked: false,
            block_reason: None,
            username: None,
        };
        save_ip_access_log(&new_log).expect("setup: saving new log should succeed");

        // Remove logs older than one day (per the original test's intent for the argument).
        let deleted = cleanup_old_ip_logs(1);
        assert!(deleted.is_ok());
        assert!(deleted.unwrap() >= 1, "Should delete at least 1 old log");

        // The new log survives...
        let logs = get_ip_access_logs(10, 0, Some("new.log.ip"), false).unwrap();
        assert!(!logs.is_empty(), "New log should still exist");

        // ...and the old one is gone.
        let old_logs = get_ip_access_logs(10, 0, Some("old.log.ip"), false).unwrap();
        assert!(old_logs.is_empty(), "Old log should be cleaned up");

        cleanup_test_data();
    }

    // ============================================================================
    // Category 9: concurrency
    // ============================================================================

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        // The lock keeps *other tests* out. The worker threads below do not take it;
        // they race only with each other, which is what this test exercises.
        let _guard = setup();

        let handles: Vec<_> = (0..10)
            .map(|i| {
                thread::spawn(move || {
                    // Each thread adds a distinct IP, then checks its own entry.
                    let ip = format!("concurrent.test.{}", i);
                    let _ = add_to_blacklist(&ip, Some("Concurrent test"), None, "test");
                    is_ip_in_blacklist(&ip).unwrap_or(false)
                })
            })
            .collect();

        let results: Vec<bool> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        assert!(results.iter().all(|&r| r), "All concurrent adds should succeed");

        cleanup_test_data();
    }

    // ============================================================================
    // Category 10: edge cases and error handling
    // ============================================================================

    #[test]
    fn test_duplicate_blacklist_entry() {
        let _guard = setup();

        let result1 = add_to_blacklist("duplicate.test.ip", Some("First"), None, "test");
        assert!(result1.is_ok());

        // Adding the same pattern again is expected to fail (UNIQUE constraint).
        let result2 = add_to_blacklist("duplicate.test.ip", Some("Second"), None, "test");
        assert!(result2.is_err(), "Duplicate IP should fail");

        cleanup_test_data();
    }

    #[test]
    fn test_empty_ip_pattern() {
        let _guard = setup();

        // Whether an empty pattern is accepted is a business decision. This test only
        // checks that the call does not panic, so the result is intentionally ignored.
        let _ = add_to_blacklist("", Some("Empty IP"), None, "test");

        cleanup_test_data();
    }

    #[test]
    fn test_special_characters_in_reason() {
        let _guard = setup();

        // Quotes and non-ASCII characters must round-trip unchanged.
        let reason = "Test with 'quotes' and \"double quotes\" and emoji 🚫";
        let result = add_to_blacklist("special.char.test", Some(reason), None, "test");
        assert!(result.is_ok());

        let entry = get_blacklist_entry_for_ip("special.char.test").unwrap().unwrap();
        assert_eq!(entry.reason.as_deref(), Some(reason));

        cleanup_test_data();
    }

    #[test]
    fn test_hit_count_increment() {
        let _guard = setup();

        add_to_blacklist("hit.count.test", Some("Count test"), None, "test")
            .expect("setup: adding hit.count.test should succeed");

        // The test expects each lookup through get_blacklist_entry_for_ip to increment
        // hit_count.
        for _ in 0..5 {
            let _ = get_blacklist_entry_for_ip("hit.count.test");
        }

        let blacklist = get_blacklist().unwrap();
        let entry = blacklist.iter().find(|e| e.ip_pattern == "hit.count.test");
        assert!(entry.is_some());
        assert!(entry.unwrap().hit_count >= 5, "Hit count should be at least 5");

        cleanup_test_data();
    }
}
// ============================================================================
// IP filter middleware tests (unit tests)
// ============================================================================
#[cfg(test)]
mod ip_filter_middleware_tests {
    // NOTE: testing the real middleware requires constructing HTTP requests. This module
    // only checks the header-parsing rule conceptually. Real integration tests should run
    // against a fully started service.

    /// Returns the first hop of an `X-Forwarded-For`-style value: the text before the
    /// first comma, trimmed of surrounding whitespace.
    ///
    /// NOTE(review): this mirrors the intended middleware rule; it does not call the
    /// middleware's own parser, so keep the two in sync.
    fn first_forwarded_ip(header: &str) -> &str {
        header.split(',').next().unwrap_or(header).trim()
    }

    /// Checks the client-IP extraction rule for forwarded headers.
    ///
    /// Intended priority (described here, not exercised): X-Forwarded-For first, then
    /// X-Real-IP, then the socket address from ConnectInfo.
    #[test]
    fn test_ip_extraction_priority() {
        // Scenario 1: X-Forwarded-For carries several hops; the first one is the client.
        assert_eq!(
            first_forwarded_ip("203.0.113.1, 198.51.100.2, 192.0.2.3"),
            "203.0.113.1"
        );

        // Scenario 2: a single IP.
        assert_eq!(first_forwarded_ip("10.0.0.1"), "10.0.0.1");

        // Scenario 3: whitespace around the first hop is trimmed.
        assert_eq!(first_forwarded_ip("  198.51.100.7 ,10.0.0.2"), "198.51.100.7");
    }
}
// ============================================================================
// Performance benchmarks
// ============================================================================
#[cfg(test)]
mod performance_benchmarks {
    // Brings in the shared `lock_db` and `clear_blacklist` helpers, so the benchmarks
    // do not wipe the blacklist while a security_db test is running.
    use super::security_db_tests::*;
    use crate::modules::security_db::{add_to_blacklist, init_db, is_ip_in_blacklist};
    use std::time::Instant;

    /// Benchmark: exact-match blacklist lookup speed with 100 entries present.
    #[test]
    fn benchmark_blacklist_lookup() {
        let _guard = lock_db();
        let _ = init_db();

        // Start from an empty blacklist, then add 100 exact-match entries.
        clear_blacklist();
        for i in 0..100 {
            let _ = add_to_blacklist(&format!("bench.ip.{}", i), Some("Benchmark"), None, "test");
        }

        // Time 1000 lookups.
        let start = Instant::now();
        for _ in 0..1000 {
            let _ = is_ip_in_blacklist("bench.ip.50");
        }
        let duration = start.elapsed();
        println!("1000 blacklist lookups took: {:?}", duration);
        println!("Average per lookup: {:?}", duration / 1000);

        // Deliberately loose smoke check: 5000 ms total for 1000 lookups, i.e. under
        // 5 ms per lookup on average.
        assert!(
            duration.as_millis() < 5000,
            "Blacklist lookup should be fast (< 5ms avg)"
        );

        clear_blacklist();
    }

    /// Benchmark: CIDR matching speed with 20 /16 rules present.
    #[test]
    fn benchmark_cidr_matching() {
        let _guard = lock_db();
        let _ = init_db();

        // Start from an empty blacklist, then add 20 CIDR rules (10.0.0.0/16 .. 10.19.0.0/16).
        clear_blacklist();
        for i in 0..20 {
            let _ = add_to_blacklist(
                &format!("10.{}.0.0/16", i),
                Some("CIDR Benchmark"),
                None,
                "test",
            );
        }

        // Time 1000 lookups of an IP that is only matched through a CIDR rule.
        let start = Instant::now();
        for _ in 0..1000 {
            let _ = is_ip_in_blacklist("10.5.100.50");
        }
        let duration = start.elapsed();
        println!("1000 CIDR matches took: {:?}", duration);
        println!("Average per match: {:?}", duration / 1000);

        // Deliberately loose smoke check: 5000 ms total for 1000 matches.
        assert!(
            duration.as_millis() < 5000,
            "CIDR matching should be reasonably fast"
        );

        clear_blacklist();
    }
}