use crate::interfaces::http::api::AppState; use axum::{ body::Body, extract::State, http::{HeaderValue, StatusCode}, middleware::Next, }; use once_cell::sync::Lazy; use std::sync::Arc; use std::time::{Duration, Instant}; use uuid::Uuid; use dashmap::DashMap; #[derive(Clone, Debug)] pub struct RequestId(pub String); // ─── Pre-allocated Security Headers (avoid .parse() on every request) ─── static HEADER_NOSNIFF: HeaderValue = HeaderValue::from_static("nosniff"); static HEADER_DENY: HeaderValue = HeaderValue::from_static("DENY"); static HEADER_XSS: HeaderValue = HeaderValue::from_static("1; mode=block"); static HEADER_REFERRER: HeaderValue = HeaderValue::from_static("strict-origin-when-cross-origin"); static HEADER_HSTS: HeaderValue = HeaderValue::from_static("max-age=31536000; includeSubDomains"); static HEADER_CSP: HeaderValue = HeaderValue::from_static( "default-src 'self'; script-src 'self' https://*.sentry.io https://*.razorpay.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' https://*.hf.space wss://*.hf.space https://*.onrender.com wss://*.onrender.com https://*.sentry.io https://*.razorpay.com; frame-src 'self' https://*.razorpay.com; frame-ancestors 'none'" ); // ─── Rate Limiter ─── #[derive(Clone)] pub struct RateLimitStore { store: Arc>>, max_requests: usize, window_secs: u64, } impl RateLimitStore { pub fn new(max_requests: usize, window_secs: u64) -> Self { let store = Arc::new(DashMap::new()); // Spawn background cleanup task to prevent memory leaks. // Runs every 10 seconds to keep memory bounded without blocking the request hot path. let cleanup_store = store.clone(); let cleanup_window = window_secs; tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(10)); loop { interval.tick().await; let now = Instant::now(); let window = Duration::from_secs(cleanup_window); // Evict expired timestamps cleanup_store.retain(|_key, timestamps: &mut Vec| { timestamps.retain(|&t| now.duration_since(t) < window); !timestamps.is_empty() }); // Fail-safe protection: if active keys still exceed 100,000, clear to prevent OOM if cleanup_store.len() > 100_000 { cleanup_store.clear(); tracing::warn!("RateLimitStore cleared automatically due to DDoS high-capacity threshold"); } } }); Self { store, max_requests, window_secs, } } pub fn is_allowed(&self, key: &str) -> bool { if std::env::var("DISABLE_RATE_LIMIT").map(|v| v == "true").unwrap_or(false) { return true; } let now = Instant::now(); let window = Duration::from_secs(self.window_secs); // Instant fail-fast block under catastrophic DDoS surges to protect memory if self.store.len() > 120_000 { return false; } let mut entry = self.store.entry(key.to_string()).or_default(); let timestamps = entry.value_mut(); timestamps.retain(|&t| now.duration_since(t) < window); if timestamps.len() < self.max_requests { timestamps.push(now); true } else { false } } } pub static RATE_LIMIT_STORE: Lazy = Lazy::new(|| RateLimitStore::new(100, 60)); use std::sync::atomic::{AtomicU64, Ordering}; pub struct ConcurrentBloomFilter { bits: [AtomicU64; 1024], } impl Default for ConcurrentBloomFilter { fn default() -> Self { Self::new() } } impl ConcurrentBloomFilter { pub fn new() -> Self { Self { bits: std::array::from_fn(|_| AtomicU64::new(0)), } } pub fn clear(&self) { for slot in &self.bits { slot.store(0, Ordering::Relaxed); } } fn hash1(ip: &str) -> u32 { let mut hash = 2166136261u32; for byte in ip.as_bytes() { hash ^= *byte as u32; hash = hash.wrapping_mul(16777619); } hash } fn hash2(ip: &str) -> u32 { let mut hash = 5381u32; for byte in ip.as_bytes() { hash = hash.wrapping_mul(33).wrapping_add(*byte as u32); } hash } pub fn insert(&self, ip: &str) { let h1 = Self::hash1(ip); let h2 = Self::hash2(ip); for i in 0..3 { let index = (h1.wrapping_add(i * h2) % 65536) as usize; let bucket = index / 64; let bit = index % 64; let mask = 1u64 << bit; self.bits[bucket].fetch_or(mask, Ordering::SeqCst); } } pub fn contains(&self, ip: &str) -> bool { let h1 = Self::hash1(ip); let h2 = Self::hash2(ip); for i in 0..3 { let index = (h1.wrapping_add(i * h2) % 65536) as usize; let bucket = index / 64; let bit = index % 64; let mask = 1u64 << bit; let value = self.bits[bucket].load(Ordering::Relaxed); if (value & mask) == 0 { return false; } } true } } pub static BLOOM_FILTER: Lazy = Lazy::new(ConcurrentBloomFilter::new); // ─── In-Memory Blocked IP Cache ─── // Refreshed every 60 seconds from the database, eliminating per-request DB queries. pub static BLOCKED_IP_CACHE: Lazy>> = Lazy::new(|| Arc::new(DashMap::new())); /// Spawns a background task that refreshes the blocked IP cache from the database. /// Call this once at startup from main.rs. pub fn spawn_blocked_ip_cache_refresher(pool: crate::infrastructure::db::DbPool) { let cache = BLOCKED_IP_CACHE.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(60)); loop { interval.tick().await; match sqlx::query_as::<_, (String,)>( "SELECT ip FROM security_blocks WHERE expires_at > NOW()", ) .fetch_all(&pool) .await { Ok(rows) => { cache.clear(); BLOOM_FILTER.clear(); let now = Instant::now(); for (ip,) in rows { cache.insert(ip.clone(), now); BLOOM_FILTER.insert(&ip); } tracing::debug!( "Blocked IP cache & Bloom Filter refreshed: {} entries", cache.len() ); } Err(e) => { tracing::warn!("Failed to refresh blocked IP cache: {}", e); // Keep stale cache rather than clearing — safer fallback. } } } }); } pub fn get_client_ip(req: &axum::extract::Request) -> String { req.headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) .or_else(|| req.headers().get("x-real-ip").and_then(|v| v.to_str().ok())) .or_else(|| req.headers().get("cf-connecting-ip").and_then(|v| v.to_str().ok())) .or_else(|| req.headers().get("true-client-ip").and_then(|v| v.to_str().ok())) .and_then(|ip| ip.split(',').next()) .map(|ip| ip.trim()) .unwrap_or("0.0.0.0") .to_string() } pub fn rate_limit_response(limit_type: &str) -> axum::http::Response { let window = 60u64; let mut response = axum::http::Response::new(Body::from(format!( "Too Many Requests - Rate limit exceeded for {}. Try again in {} seconds.", limit_type, window ))); *response.status_mut() = StatusCode::TOO_MANY_REQUESTS; response.headers_mut().insert( "Retry-After", HeaderValue::from_str(&window.to_string()).unwrap_or(HeaderValue::from_static("60")), ); response } pub async fn security_middleware( State(state): State, req: axum::extract::Request, next: Next, ) -> axum::response::Response { let path = req.uri().path().to_string(); let client_ip = get_client_ip(&req); let request_id = Uuid::new_v4().to_string(); // 1. Lock-free Bloom Filter IP Block check if BLOOM_FILTER.contains(&client_ip) { // Fast-path confirm via DashMap if bloom hit occurs if BLOCKED_IP_CACHE.contains_key(&client_ip) { metrics::counter!("rtix_firewall_blocks_total", "reason" => "persistent").increment(1); tracing::error!(ip = %client_ip, "Persistent block encountered (cached & confirmed)"); let mut res = axum::response::Response::new(Body::from( "Access Denied: Your IP is persistently blocked for protocol violations.", )); *res.status_mut() = StatusCode::FORBIDDEN; return res; } } // 2. Rate Limiting Logic (in-memory, no DB) let rate_limit_type = if path.starts_with("/v1/auth/") { "auth" } else if path.starts_with("/v1/checkout/") || path.starts_with("/v1/payment/") { "strict" } else { "general" }; let rate_key = format!("{}:{}", rate_limit_type, client_ip); let allowed = RATE_LIMIT_STORE.is_allowed(&rate_key); if !allowed { metrics::counter!("rtix_firewall_blocks_total", "reason" => "rate_limit", "type" => rate_limit_type.to_string()).increment(1); tracing::warn!(ip = %client_ip, limit_type = %rate_limit_type, "Rate limit exceeded"); // Auto-persistent block if strictly violated (e.g. brute force auth) if rate_limit_type == "auth" { metrics::counter!("rtix_firewall_blocks_total", "reason" => "auto_ban").increment(1); block_ip_persistently( &state.pool, &client_ip, "Automated Ban: Repeated Auth Violations", Some(&state.tx), ) .await; } return rate_limit_response(rate_limit_type); } // 3. Request Logging & Headers let origin = req .headers() .get("origin") .and_then(|v| v.to_str().ok()) .unwrap_or("none"); tracing::info!(ip = %client_ip, method = %req.method(), path = %path, origin = %origin, "Request started"); let mut req = req; req.headers_mut().insert( "x-request-id", HeaderValue::from_str(&request_id).unwrap_or(HeaderValue::from_static("unknown")), ); // Store in extensions for services to consume req.extensions_mut().insert(RequestId(request_id.clone())); let response = next.run(req).await; let mut response = response; response.headers_mut().insert( "x-request-id", HeaderValue::from_str(&request_id).unwrap_or(HeaderValue::from_static("unknown")), ); // Security Hardening Headers (pre-allocated statics — zero allocation cost) response .headers_mut() .insert("X-Content-Type-Options", HEADER_NOSNIFF.clone()); response .headers_mut() .insert("X-Frame-Options", HEADER_DENY.clone()); response .headers_mut() .insert("X-XSS-Protection", HEADER_XSS.clone()); response .headers_mut() .insert("Referrer-Policy", HEADER_REFERRER.clone()); response .headers_mut() .insert("Strict-Transport-Security", HEADER_HSTS.clone()); response .headers_mut() .insert("Content-Security-Policy", HEADER_CSP.clone()); tracing::info!( ip = %client_ip, request_id = %request_id, status = %response.status(), "Request completed" ); response } /// Manually block an IP address persistently across restarts and sessions. /// This inserts the IP into the database and the live in-memory cache. pub async fn block_ip_persistently( pool: &crate::infrastructure::db::DbPool, ip: &str, reason: &str, tx: Option<&tokio::sync::broadcast::Sender>, ) { let _ = sqlx::query("INSERT INTO security_blocks (ip, reason, block_level, expires_at) VALUES ($1, $2, $3, NOW() + INTERVAL '24 hours') ON CONFLICT (ip) DO UPDATE SET expires_at = NOW() + INTERVAL '24 hours', reason = EXCLUDED.reason") .bind(ip) .bind(reason) .bind("BLOCK") .execute(pool) .await; BLOCKED_IP_CACHE.insert(ip.to_string(), Instant::now()); BLOOM_FILTER.insert(ip); if let Some(sender) = tx { let _ = sender.send(crate::interfaces::http::api::RealtimeEvent::SentinelBlock { ip: ip.to_string(), reason: reason.to_string(), }); } tracing::error!(ip = %ip, reason = %reason, "IP persistently blocked via automated defense"); } pub async fn api_key_auth( State(state): State, req: axum::extract::Request, next: Next, ) -> axum::response::Response { let key_header = req.headers().get("x-api-key").and_then(|v| v.to_str().ok()); let key_str = match key_header { Some(k) => k, None => { let mut res = axum::response::Response::new(Body::from("Missing X-API-Key header")); *res.status_mut() = StatusCode::UNAUTHORIZED; return res; } }; let parts: Vec<&str> = key_str.split('.').collect(); if parts.len() != 2 { let mut res = axum::response::Response::new(Body::from( "Invalid API key format. Expected 'id.secret'", )); *res.status_mut() = StatusCode::BAD_REQUEST; return res; } let key_id = parts[0]; let secret = parts[1]; let key_record = match crate::domain::models::ApiKeyRecord::find_by_id(&state.pool, key_id).await { Ok(Some(k)) => k, _ => { let mut res = axum::response::Response::new(Body::from("Invalid API key")); *res.status_mut() = StatusCode::UNAUTHORIZED; return res; } }; // Verify secret hash use argon2::{Argon2, PasswordHash, PasswordVerifier}; let parsed_hash = match PasswordHash::new(&key_record.secret_hash) { Ok(h) => h, Err(_) => { let mut res = axum::response::Response::new(Body::from("Internal security error")); *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; return res; } }; if Argon2::default() .verify_password(secret.as_bytes(), &parsed_hash) .is_err() { let mut res = axum::response::Response::new(Body::from("Invalid API secret")); *res.status_mut() = StatusCode::UNAUTHORIZED; return res; } // Add merchant_id to request extensions for later extractors let mut req = req; req.extensions_mut().insert(key_record.merchant_id); next.run(req).await }