Spaces:
Running
Running
| use crate::interfaces::http::api::AppState; | |
| use axum::{ | |
| body::Body, | |
| extract::State, | |
| http::{HeaderValue, StatusCode}, | |
| middleware::Next, | |
| }; | |
| use once_cell::sync::Lazy; | |
| use std::sync::Arc; | |
| use std::time::{Duration, Instant}; | |
| use uuid::Uuid; | |
| use dashmap::DashMap; | |
/// Per-request correlation identifier.
///
/// `security_middleware` stores one of these in the request extensions so that
/// downstream code can read the same id that is echoed in the `x-request-id`
/// header.
///
/// `Clone` is derived because `http` 1.x `Extensions::insert` only accepts
/// values that are `Clone + Send + Sync + 'static`; `Debug` makes the id
/// printable in logs and assertions.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
| // ─── Pre-allocated Security Headers (avoid .parse() on every request) ─── | |
| static HEADER_NOSNIFF: HeaderValue = HeaderValue::from_static("nosniff"); | |
| static HEADER_DENY: HeaderValue = HeaderValue::from_static("DENY"); | |
| static HEADER_XSS: HeaderValue = HeaderValue::from_static("1; mode=block"); | |
| static HEADER_REFERRER: HeaderValue = HeaderValue::from_static("strict-origin-when-cross-origin"); | |
| static HEADER_HSTS: HeaderValue = HeaderValue::from_static("max-age=31536000; includeSubDomains"); | |
| static HEADER_CSP: HeaderValue = HeaderValue::from_static( | |
| "default-src 'self'; script-src 'self' https://*.sentry.io https://*.razorpay.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' https://*.hf.space wss://*.hf.space https://*.onrender.com wss://*.onrender.com https://*.sentry.io https://*.razorpay.com; frame-src 'self' https://*.razorpay.com; frame-ancestors 'none'" | |
| ); | |
| // ─── Rate Limiter ─── | |
| pub struct RateLimitStore { | |
| store: Arc<DashMap<String, Vec<Instant>>>, | |
| max_requests: usize, | |
| window_secs: u64, | |
| } | |
| impl RateLimitStore { | |
| pub fn new(max_requests: usize, window_secs: u64) -> Self { | |
| let store = Arc::new(DashMap::new()); | |
| // Spawn background cleanup task to prevent memory leaks. | |
| // Runs every 10 seconds to keep memory bounded without blocking the request hot path. | |
| let cleanup_store = store.clone(); | |
| let cleanup_window = window_secs; | |
| tokio::spawn(async move { | |
| let mut interval = tokio::time::interval(Duration::from_secs(10)); | |
| loop { | |
| interval.tick().await; | |
| let now = Instant::now(); | |
| let window = Duration::from_secs(cleanup_window); | |
| // Evict expired timestamps | |
| cleanup_store.retain(|_key, timestamps: &mut Vec<Instant>| { | |
| timestamps.retain(|&t| now.duration_since(t) < window); | |
| !timestamps.is_empty() | |
| }); | |
| // Fail-safe protection: if active keys still exceed 100,000, clear to prevent OOM | |
| if cleanup_store.len() > 100_000 { | |
| cleanup_store.clear(); | |
| tracing::warn!("RateLimitStore cleared automatically due to DDoS high-capacity threshold"); | |
| } | |
| } | |
| }); | |
| Self { | |
| store, | |
| max_requests, | |
| window_secs, | |
| } | |
| } | |
| pub fn is_allowed(&self, key: &str) -> bool { | |
| if std::env::var("DISABLE_RATE_LIMIT").map(|v| v == "true").unwrap_or(false) { | |
| return true; | |
| } | |
| let now = Instant::now(); | |
| let window = Duration::from_secs(self.window_secs); | |
| // Instant fail-fast block under catastrophic DDoS surges to protect memory | |
| if self.store.len() > 120_000 { | |
| return false; | |
| } | |
| let mut entry = self.store.entry(key.to_string()).or_default(); | |
| let timestamps = entry.value_mut(); | |
| timestamps.retain(|&t| now.duration_since(t) < window); | |
| if timestamps.len() < self.max_requests { | |
| timestamps.push(now); | |
| true | |
| } else { | |
| false | |
| } | |
| } | |
| } | |
| pub static RATE_LIMIT_STORE: Lazy<RateLimitStore> = Lazy::new(|| RateLimitStore::new(100, 60)); | |
| use std::sync::atomic::{AtomicU64, Ordering}; | |
/// Fixed-size Bloom filter over IP strings, usable from many threads without
/// locking.
///
/// The filter holds 1024 × 64 = 65,536 bits and sets/tests three bit
/// positions per key, derived from two string hashes (double hashing).
/// A key that has been inserted since the last `clear` always tests positive;
/// a key that was never inserted may occasionally test positive as well.
pub struct ConcurrentBloomFilter {
    bits: [AtomicU64; 1024],
}

impl Default for ConcurrentBloomFilter {
    fn default() -> Self {
        Self::new()
    }
}

impl ConcurrentBloomFilter {
    /// Number of bit positions set/tested per key.
    const PROBES: u32 = 3;
    /// Total number of bits in the filter (1024 words × 64 bits).
    const TOTAL_BITS: u32 = 65536;
    /// Bits per storage word.
    const WORD_BITS: usize = 64;

    /// Creates an empty filter (all bits zero).
    pub fn new() -> Self {
        Self {
            bits: std::array::from_fn(|_| AtomicU64::new(0)),
        }
    }

    /// Resets every bit to zero.
    ///
    /// Words are cleared one at a time, so concurrent readers may observe a
    /// partially cleared filter while this runs.
    pub fn clear(&self) {
        for slot in &self.bits {
            slot.store(0, Ordering::Relaxed);
        }
    }

    /// 32-bit FNV-1a hash of `ip`.
    fn hash1(ip: &str) -> u32 {
        ip.bytes().fold(2166136261u32, |hash, byte| {
            (hash ^ u32::from(byte)).wrapping_mul(16777619)
        })
    }

    /// 32-bit DJB2 hash of `ip` (`hash * 33 + byte`, seeded with 5381).
    fn hash2(ip: &str) -> u32 {
        ip.bytes().fold(5381u32, |hash, byte| {
            hash.wrapping_mul(33).wrapping_add(u32::from(byte))
        })
    }

    /// Yields the `(word index, bit mask)` pairs probed for `ip`.
    ///
    /// Probe `i` lands on bit `(h1 + i * h2) mod 65536`. All arithmetic wraps:
    /// `i * h2` exceeds `u32::MAX` for roughly half of all keys, and a plain
    /// `*` would panic there in builds with overflow checks enabled.
    fn probes(ip: &str) -> impl Iterator<Item = (usize, u64)> {
        let h1 = Self::hash1(ip);
        let h2 = Self::hash2(ip);
        (0..Self::PROBES).map(move |i| {
            let index = (h1.wrapping_add(i.wrapping_mul(h2)) % Self::TOTAL_BITS) as usize;
            (index / Self::WORD_BITS, 1u64 << (index % Self::WORD_BITS))
        })
    }

    /// Adds `ip` to the filter.
    pub fn insert(&self, ip: &str) {
        for (bucket, mask) in Self::probes(ip) {
            self.bits[bucket].fetch_or(mask, Ordering::SeqCst);
        }
    }

    /// Returns `false` if `ip` was definitely not inserted since the last
    /// `clear`, and `true` if it may have been.
    pub fn contains(&self, ip: &str) -> bool {
        Self::probes(ip)
            .all(|(bucket, mask)| self.bits[bucket].load(Ordering::Relaxed) & mask != 0)
    }
}
| pub static BLOOM_FILTER: Lazy<ConcurrentBloomFilter> = Lazy::new(ConcurrentBloomFilter::new); | |
| // ─── In-Memory Blocked IP Cache ─── | |
| // Refreshed every 60 seconds from the database, eliminating per-request DB queries. | |
| pub static BLOCKED_IP_CACHE: Lazy<Arc<DashMap<String, Instant>>> = | |
| Lazy::new(|| Arc::new(DashMap::new())); | |
| /// Spawns a background task that refreshes the blocked IP cache from the database. | |
| /// Call this once at startup from main.rs. | |
| pub fn spawn_blocked_ip_cache_refresher(pool: crate::infrastructure::db::DbPool) { | |
| let cache = BLOCKED_IP_CACHE.clone(); | |
| tokio::spawn(async move { | |
| let mut interval = tokio::time::interval(Duration::from_secs(60)); | |
| loop { | |
| interval.tick().await; | |
| match sqlx::query_as::<_, (String,)>( | |
| "SELECT ip FROM security_blocks WHERE expires_at > NOW()", | |
| ) | |
| .fetch_all(&pool) | |
| .await | |
| { | |
| Ok(rows) => { | |
| cache.clear(); | |
| BLOOM_FILTER.clear(); | |
| let now = Instant::now(); | |
| for (ip,) in rows { | |
| cache.insert(ip.clone(), now); | |
| BLOOM_FILTER.insert(&ip); | |
| } | |
| tracing::debug!( | |
| "Blocked IP cache & Bloom Filter refreshed: {} entries", | |
| cache.len() | |
| ); | |
| } | |
| Err(e) => { | |
| tracing::warn!("Failed to refresh blocked IP cache: {}", e); | |
| // Keep stale cache rather than clearing — safer fallback. | |
| } | |
| } | |
| } | |
| }); | |
| } | |
| pub fn get_client_ip(req: &axum::extract::Request) -> String { | |
| req.headers() | |
| .get("x-forwarded-for") | |
| .and_then(|v| v.to_str().ok()) | |
| .or_else(|| req.headers().get("x-real-ip").and_then(|v| v.to_str().ok())) | |
| .or_else(|| req.headers().get("cf-connecting-ip").and_then(|v| v.to_str().ok())) | |
| .or_else(|| req.headers().get("true-client-ip").and_then(|v| v.to_str().ok())) | |
| .and_then(|ip| ip.split(',').next()) | |
| .map(|ip| ip.trim()) | |
| .unwrap_or("0.0.0.0") | |
| .to_string() | |
| } | |
| pub fn rate_limit_response(limit_type: &str) -> axum::http::Response<Body> { | |
| let window = 60u64; | |
| let mut response = axum::http::Response::new(Body::from(format!( | |
| "Too Many Requests - Rate limit exceeded for {}. Try again in {} seconds.", | |
| limit_type, window | |
| ))); | |
| *response.status_mut() = StatusCode::TOO_MANY_REQUESTS; | |
| response.headers_mut().insert( | |
| "Retry-After", | |
| HeaderValue::from_str(&window.to_string()).unwrap_or(HeaderValue::from_static("60")), | |
| ); | |
| response | |
| } | |
| pub async fn security_middleware( | |
| State(state): State<AppState>, | |
| req: axum::extract::Request, | |
| next: Next, | |
| ) -> axum::response::Response { | |
| let path = req.uri().path().to_string(); | |
| let client_ip = get_client_ip(&req); | |
| let request_id = Uuid::new_v4().to_string(); | |
| // 1. Lock-free Bloom Filter IP Block check | |
| if BLOOM_FILTER.contains(&client_ip) { | |
| // Fast-path confirm via DashMap if bloom hit occurs | |
| if BLOCKED_IP_CACHE.contains_key(&client_ip) { | |
| metrics::counter!("rtix_firewall_blocks_total", "reason" => "persistent").increment(1); | |
| tracing::error!(ip = %client_ip, "Persistent block encountered (cached & confirmed)"); | |
| let mut res = axum::response::Response::new(Body::from( | |
| "Access Denied: Your IP is persistently blocked for protocol violations.", | |
| )); | |
| *res.status_mut() = StatusCode::FORBIDDEN; | |
| return res; | |
| } | |
| } | |
| // 2. Rate Limiting Logic (in-memory, no DB) | |
| let rate_limit_type = if path.starts_with("/v1/auth/") { | |
| "auth" | |
| } else if path.starts_with("/v1/checkout/") || path.starts_with("/v1/payment/") { | |
| "strict" | |
| } else { | |
| "general" | |
| }; | |
| let rate_key = format!("{}:{}", rate_limit_type, client_ip); | |
| let allowed = RATE_LIMIT_STORE.is_allowed(&rate_key); | |
| if !allowed { | |
| metrics::counter!("rtix_firewall_blocks_total", "reason" => "rate_limit", "type" => rate_limit_type.to_string()).increment(1); | |
| tracing::warn!(ip = %client_ip, limit_type = %rate_limit_type, "Rate limit exceeded"); | |
| // Auto-persistent block if strictly violated (e.g. brute force auth) | |
| if rate_limit_type == "auth" { | |
| metrics::counter!("rtix_firewall_blocks_total", "reason" => "auto_ban").increment(1); | |
| block_ip_persistently( | |
| &state.pool, | |
| &client_ip, | |
| "Automated Ban: Repeated Auth Violations", | |
| Some(&state.tx), | |
| ) | |
| .await; | |
| } | |
| return rate_limit_response(rate_limit_type); | |
| } | |
| // 3. Request Logging & Headers | |
| let origin = req | |
| .headers() | |
| .get("origin") | |
| .and_then(|v| v.to_str().ok()) | |
| .unwrap_or("none"); | |
| tracing::info!(ip = %client_ip, method = %req.method(), path = %path, origin = %origin, "Request started"); | |
| let mut req = req; | |
| req.headers_mut().insert( | |
| "x-request-id", | |
| HeaderValue::from_str(&request_id).unwrap_or(HeaderValue::from_static("unknown")), | |
| ); | |
| // Store in extensions for services to consume | |
| req.extensions_mut().insert(RequestId(request_id.clone())); | |
| let response = next.run(req).await; | |
| let mut response = response; | |
| response.headers_mut().insert( | |
| "x-request-id", | |
| HeaderValue::from_str(&request_id).unwrap_or(HeaderValue::from_static("unknown")), | |
| ); | |
| // Security Hardening Headers (pre-allocated statics — zero allocation cost) | |
| response | |
| .headers_mut() | |
| .insert("X-Content-Type-Options", HEADER_NOSNIFF.clone()); | |
| response | |
| .headers_mut() | |
| .insert("X-Frame-Options", HEADER_DENY.clone()); | |
| response | |
| .headers_mut() | |
| .insert("X-XSS-Protection", HEADER_XSS.clone()); | |
| response | |
| .headers_mut() | |
| .insert("Referrer-Policy", HEADER_REFERRER.clone()); | |
| response | |
| .headers_mut() | |
| .insert("Strict-Transport-Security", HEADER_HSTS.clone()); | |
| response | |
| .headers_mut() | |
| .insert("Content-Security-Policy", HEADER_CSP.clone()); | |
| tracing::info!( | |
| ip = %client_ip, | |
| request_id = %request_id, | |
| status = %response.status(), | |
| "Request completed" | |
| ); | |
| response | |
| } | |
| /// Manually block an IP address persistently across restarts and sessions. | |
| /// This inserts the IP into the database and the live in-memory cache. | |
| pub async fn block_ip_persistently( | |
| pool: &crate::infrastructure::db::DbPool, | |
| ip: &str, | |
| reason: &str, | |
| tx: Option<&tokio::sync::broadcast::Sender<crate::interfaces::http::api::RealtimeEvent>>, | |
| ) { | |
| let _ = sqlx::query("INSERT INTO security_blocks (ip, reason, block_level, expires_at) VALUES ($1, $2, $3, NOW() + INTERVAL '24 hours') ON CONFLICT (ip) DO UPDATE SET expires_at = NOW() + INTERVAL '24 hours', reason = EXCLUDED.reason") | |
| .bind(ip) | |
| .bind(reason) | |
| .bind("BLOCK") | |
| .execute(pool) | |
| .await; | |
| BLOCKED_IP_CACHE.insert(ip.to_string(), Instant::now()); | |
| BLOOM_FILTER.insert(ip); | |
| if let Some(sender) = tx { | |
| let _ = sender.send(crate::interfaces::http::api::RealtimeEvent::SentinelBlock { | |
| ip: ip.to_string(), | |
| reason: reason.to_string(), | |
| }); | |
| } | |
| tracing::error!(ip = %ip, reason = %reason, "IP persistently blocked via automated defense"); | |
| } | |
| pub async fn api_key_auth( | |
| State(state): State<AppState>, | |
| req: axum::extract::Request, | |
| next: Next, | |
| ) -> axum::response::Response { | |
| let key_header = req.headers().get("x-api-key").and_then(|v| v.to_str().ok()); | |
| let key_str = match key_header { | |
| Some(k) => k, | |
| None => { | |
| let mut res = axum::response::Response::new(Body::from("Missing X-API-Key header")); | |
| *res.status_mut() = StatusCode::UNAUTHORIZED; | |
| return res; | |
| } | |
| }; | |
| let parts: Vec<&str> = key_str.split('.').collect(); | |
| if parts.len() != 2 { | |
| let mut res = axum::response::Response::new(Body::from( | |
| "Invalid API key format. Expected 'id.secret'", | |
| )); | |
| *res.status_mut() = StatusCode::BAD_REQUEST; | |
| return res; | |
| } | |
| let key_id = parts[0]; | |
| let secret = parts[1]; | |
| let key_record = | |
| match crate::domain::models::ApiKeyRecord::find_by_id(&state.pool, key_id).await { | |
| Ok(Some(k)) => k, | |
| _ => { | |
| let mut res = axum::response::Response::new(Body::from("Invalid API key")); | |
| *res.status_mut() = StatusCode::UNAUTHORIZED; | |
| return res; | |
| } | |
| }; | |
| // Verify secret hash | |
| use argon2::{Argon2, PasswordHash, PasswordVerifier}; | |
| let parsed_hash = match PasswordHash::new(&key_record.secret_hash) { | |
| Ok(h) => h, | |
| Err(_) => { | |
| let mut res = axum::response::Response::new(Body::from("Internal security error")); | |
| *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; | |
| return res; | |
| } | |
| }; | |
| if Argon2::default() | |
| .verify_password(secret.as_bytes(), &parsed_hash) | |
| .is_err() | |
| { | |
| let mut res = axum::response::Response::new(Body::from("Invalid API secret")); | |
| *res.status_mut() = StatusCode::UNAUTHORIZED; | |
| return res; | |
| } | |
| // Add merchant_id to request extensions for later extractors | |
| let mut req = req; | |
| req.extensions_mut().insert(key_record.merchant_id); | |
| next.run(req).await | |
| } | |