// RTIX / src/interfaces/http/middleware.rs
// github-actions
// deploy: clean backend production release
// d8ffec9
use crate::interfaces::http::api::AppState;
use axum::{
body::Body,
extract::State,
http::{HeaderValue, StatusCode},
middleware::Next,
};
use once_cell::sync::Lazy;
use std::sync::Arc;
use std::time::{Duration, Instant};
use uuid::Uuid;
use dashmap::DashMap;
/// Per-request correlation id carried in request extensions.
///
/// `security_middleware` in this module fills it with a freshly generated UUID v4
/// string and mirrors the same value in the `x-request-id` header.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
// ─── Pre-allocated Security Headers (avoid .parse() on every request) ───
// Every value below is built once in a static initializer; responses clone these
// instead of re-parsing header text per request.

/// `X-Content-Type-Options`: disable MIME-type sniffing.
static HEADER_NOSNIFF: HeaderValue = HeaderValue::from_static("nosniff");

/// `X-Frame-Options`: never allow this origin to be framed.
static HEADER_DENY: HeaderValue = HeaderValue::from_static("DENY");

/// `X-XSS-Protection` (legacy, non-standard header).
///
/// NOTE(review): OWASP's secure-headers guidance now recommends sending `0` (the CSP
/// below is the real defence), and MDN warns the old XSS auditor could itself create
/// vulnerabilities. Value left unchanged pending a policy decision.
static HEADER_XSS: HeaderValue = HeaderValue::from_static("1; mode=block");

/// `Referrer-Policy`: full referrer same-origin, origin only cross-origin, none on downgrade.
static HEADER_REFERRER: HeaderValue = HeaderValue::from_static("strict-origin-when-cross-origin");

/// `Strict-Transport-Security`: one year, including subdomains.
static HEADER_HSTS: HeaderValue = HeaderValue::from_static("max-age=31536000; includeSubDomains");

/// `Content-Security-Policy`: same-origin by default, plus the allow-listed Sentry,
/// Razorpay, Google Fonts, Hugging Face Spaces and Render origins. Written one
/// directive per line; `concat!` joins them into a single literal at compile time.
static HEADER_CSP: HeaderValue = HeaderValue::from_static(concat!(
    "default-src 'self'; ",
    "script-src 'self' https://*.sentry.io https://*.razorpay.com; ",
    "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; ",
    "font-src 'self' data: https://fonts.gstatic.com; ",
    "img-src 'self' data: https:; ",
    "connect-src 'self' https://*.hf.space wss://*.hf.space https://*.onrender.com wss://*.onrender.com https://*.sentry.io https://*.razorpay.com; ",
    "frame-src 'self' https://*.razorpay.com; ",
    "frame-ancestors 'none'"
));
// ─── Rate Limiter ───
#[derive(Clone)]
pub struct RateLimitStore {
store: Arc<DashMap<String, Vec<Instant>>>,
max_requests: usize,
window_secs: u64,
}
impl RateLimitStore {
pub fn new(max_requests: usize, window_secs: u64) -> Self {
let store = Arc::new(DashMap::new());
// Spawn background cleanup task to prevent memory leaks.
// Runs every 10 seconds to keep memory bounded without blocking the request hot path.
let cleanup_store = store.clone();
let cleanup_window = window_secs;
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
interval.tick().await;
let now = Instant::now();
let window = Duration::from_secs(cleanup_window);
// Evict expired timestamps
cleanup_store.retain(|_key, timestamps: &mut Vec<Instant>| {
timestamps.retain(|&t| now.duration_since(t) < window);
!timestamps.is_empty()
});
// Fail-safe protection: if active keys still exceed 100,000, clear to prevent OOM
if cleanup_store.len() > 100_000 {
cleanup_store.clear();
tracing::warn!("RateLimitStore cleared automatically due to DDoS high-capacity threshold");
}
}
});
Self {
store,
max_requests,
window_secs,
}
}
pub fn is_allowed(&self, key: &str) -> bool {
if std::env::var("DISABLE_RATE_LIMIT").map(|v| v == "true").unwrap_or(false) {
return true;
}
let now = Instant::now();
let window = Duration::from_secs(self.window_secs);
// Instant fail-fast block under catastrophic DDoS surges to protect memory
if self.store.len() > 120_000 {
return false;
}
let mut entry = self.store.entry(key.to_string()).or_default();
let timestamps = entry.value_mut();
timestamps.retain(|&t| now.duration_since(t) < window);
if timestamps.len() < self.max_requests {
timestamps.push(now);
true
} else {
false
}
}
}
/// Process-wide limiter used by `security_middleware`: at most 100 requests per key
/// within any 60-second sliding window.
///
/// It is built on first access, and construction spawns the store's background
/// sweeper, so that first access must happen inside a Tokio runtime.
pub static RATE_LIMIT_STORE: Lazy<RateLimitStore> = Lazy::new(|| {
    let (max_requests, window_secs) = (100, 60);
    RateLimitStore::new(max_requests, window_secs)
});
use std::sync::atomic::{AtomicU64, Ordering};
/// Fixed-size, lock-free Bloom filter over IP address strings.
///
/// It has 65,536 bits, stored in 1,024 `AtomicU64` words. Each key sets or tests
/// three bit positions derived by double hashing with FNV-1a and djb2. `contains`
/// may report false positives, but never false negatives for keys inserted since the
/// last `clear`. Callers must therefore confirm a hit against an authoritative set.
pub struct ConcurrentBloomFilter {
    bits: [AtomicU64; 1024],
}
impl Default for ConcurrentBloomFilter {
    fn default() -> Self {
        Self::new()
    }
}
impl ConcurrentBloomFilter {
    /// Number of bit positions probed per key.
    const PROBES: u32 = 3;
    /// Total addressable bits: 1,024 words of 64 bits (must match the `bits` array).
    const BIT_COUNT: u32 = 1024 * u64::BITS;

    /// Creates an empty filter (all bits clear).
    pub fn new() -> Self {
        Self {
            bits: std::array::from_fn(|_| AtomicU64::new(0)),
        }
    }

    /// Resets every bit. Until keys are re-inserted, `contains` returns `false` for all keys.
    pub fn clear(&self) {
        for slot in &self.bits {
            slot.store(0, Ordering::Relaxed);
        }
    }

    /// 32-bit FNV-1a.
    fn hash1(ip: &str) -> u32 {
        let mut hash = 2166136261u32;
        for byte in ip.as_bytes() {
            hash ^= *byte as u32;
            hash = hash.wrapping_mul(16777619);
        }
        hash
    }

    /// 32-bit djb2 (`h * 33 + byte`).
    fn hash2(ip: &str) -> u32 {
        let mut hash = 5381u32;
        for byte in ip.as_bytes() {
            hash = hash.wrapping_mul(33).wrapping_add(*byte as u32);
        }
        hash
    }

    /// Yields the `(word index, bit mask)` pair for each probe of `ip`:
    /// position `i` is `(h1 + i * h2) mod BIT_COUNT`.
    fn probes(ip: &str) -> impl Iterator<Item = (usize, u64)> {
        let h1 = Self::hash1(ip);
        let h2 = Self::hash2(ip);
        (0..Self::PROBES).map(move |i| {
            // `wrapping_mul`: a plain `i * h2` overflows whenever h2 >= 2^31, which
            // panics in overflow-checked (debug) builds. Release results are unchanged.
            let index = h1.wrapping_add(i.wrapping_mul(h2)) % Self::BIT_COUNT;
            ((index / 64) as usize, 1u64 << (index % 64))
        })
    }

    /// Adds `ip` to the filter.
    pub fn insert(&self, ip: &str) {
        for (word, mask) in Self::probes(ip) {
            self.bits[word].fetch_or(mask, Ordering::SeqCst);
        }
    }

    /// Returns `false` if `ip` was definitely not inserted since the last `clear`;
    /// `true` means "possibly inserted".
    pub fn contains(&self, ip: &str) -> bool {
        Self::probes(ip).all(|(word, mask)| (self.bits[word].load(Ordering::Relaxed) & mask) != 0)
    }
}
/// Process-wide Bloom pre-filter over blocked IPs. Every IP added to
/// `BLOCKED_IP_CACHE` is also inserted here, so the request path can skip the map
/// lookup whenever the filter reports "definitely absent".
pub static BLOOM_FILTER: Lazy<ConcurrentBloomFilter> = Lazy::new(ConcurrentBloomFilter::default);
// ─── In-Memory Blocked IP Cache ───
// Refreshed every 60 seconds from the database, eliminating per-request DB queries.
/// Currently blocked IPs, each mapped to the `Instant` at which it was last
/// (re)inserted into the cache.
pub static BLOCKED_IP_CACHE: Lazy<Arc<DashMap<String, Instant>>> =
    Lazy::new(|| Arc::new(DashMap::default()));
/// Spawns a background task that refreshes the blocked IP cache from the database.
/// Call this once at startup from main.rs.
pub fn spawn_blocked_ip_cache_refresher(pool: crate::infrastructure::db::DbPool) {
let cache = BLOCKED_IP_CACHE.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
match sqlx::query_as::<_, (String,)>(
"SELECT ip FROM security_blocks WHERE expires_at > NOW()",
)
.fetch_all(&pool)
.await
{
Ok(rows) => {
cache.clear();
BLOOM_FILTER.clear();
let now = Instant::now();
for (ip,) in rows {
cache.insert(ip.clone(), now);
BLOOM_FILTER.insert(&ip);
}
tracing::debug!(
"Blocked IP cache & Bloom Filter refreshed: {} entries",
cache.len()
);
}
Err(e) => {
tracing::warn!("Failed to refresh blocked IP cache: {}", e);
// Keep stale cache rather than clearing — safer fallback.
}
}
}
});
}
/// Best-effort client IP used as the rate-limit and block-list key.
///
/// The first of `x-forwarded-for`, `x-real-ip`, `cf-connecting-ip` and
/// `true-client-ip` that is present with a `to_str`-able value is used, taking the
/// first comma-separated entry of that value, trimmed.
/// - A bare IPv4/IPv6 address is returned verbatim.
/// - An `ip:port` or `[ipv6]:port` entry is reduced to its address. Otherwise the
///   varying port would give every connection its own rate-limit bucket.
/// - Anything else returns `"0.0.0.0"`: no header, an empty entry, or arbitrary
///   text. This keeps attacker-chosen strings of arbitrary length out of the
///   rate-limiter and block-list maps.
///
/// NOTE(review): these headers are client-controlled unless a trusted proxy strips
/// or overwrites them. The leftmost `x-forwarded-for` entry is whatever the client
/// sent, so it can be rotated to evade rate limits. It can also be set to a victim's
/// address to get that address auto-banned via the auth limit. Confirm the
/// deployment's proxy chain, and prefer the header (or rightmost hop) that the edge
/// proxy controls.
pub fn get_client_ip(req: &axum::extract::Request) -> String {
    const FALLBACK_IP: &str = "0.0.0.0";
    const CANDIDATE_HEADERS: [&str; 4] = [
        "x-forwarded-for",
        "x-real-ip",
        "cf-connecting-ip",
        "true-client-ip",
    ];
    let headers = req.headers();
    let first_entry = CANDIDATE_HEADERS
        .iter()
        .find_map(|name| headers.get(*name).and_then(|v| v.to_str().ok()))
        .and_then(|value| value.split(',').next())
        .map(str::trim);
    match first_entry {
        Some(entry) if entry.parse::<std::net::IpAddr>().is_ok() => entry.to_string(),
        Some(entry) => entry
            .parse::<std::net::SocketAddr>()
            .map(|addr| addr.ip().to_string())
            .unwrap_or_else(|_| FALLBACK_IP.to_string()),
        None => FALLBACK_IP.to_string(),
    }
}
/// Builds the `429 Too Many Requests` reply for a rejected request.
///
/// The plain-text body names the exceeded `limit_type` bucket, and the response
/// carries `Retry-After: 60`.
pub fn rate_limit_response(limit_type: &str) -> axum::http::Response<Body> {
    // Should track the 60-second window RATE_LIMIT_STORE is constructed with.
    const RETRY_AFTER_SECS: u64 = 60;
    let body = Body::from(format!(
        "Too Many Requests - Rate limit exceeded for {}. Try again in {} seconds.",
        limit_type, RETRY_AFTER_SECS
    ));
    let mut response = axum::http::Response::new(body);
    *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
    response
        .headers_mut()
        .insert("Retry-After", HeaderValue::from(RETRY_AFTER_SECS));
    response
}
pub async fn security_middleware(
State(state): State<AppState>,
req: axum::extract::Request,
next: Next,
) -> axum::response::Response {
let path = req.uri().path().to_string();
let client_ip = get_client_ip(&req);
let request_id = Uuid::new_v4().to_string();
// 1. Lock-free Bloom Filter IP Block check
if BLOOM_FILTER.contains(&client_ip) {
// Fast-path confirm via DashMap if bloom hit occurs
if BLOCKED_IP_CACHE.contains_key(&client_ip) {
metrics::counter!("rtix_firewall_blocks_total", "reason" => "persistent").increment(1);
tracing::error!(ip = %client_ip, "Persistent block encountered (cached & confirmed)");
let mut res = axum::response::Response::new(Body::from(
"Access Denied: Your IP is persistently blocked for protocol violations.",
));
*res.status_mut() = StatusCode::FORBIDDEN;
return res;
}
}
// 2. Rate Limiting Logic (in-memory, no DB)
let rate_limit_type = if path.starts_with("/v1/auth/") {
"auth"
} else if path.starts_with("/v1/checkout/") || path.starts_with("/v1/payment/") {
"strict"
} else {
"general"
};
let rate_key = format!("{}:{}", rate_limit_type, client_ip);
let allowed = RATE_LIMIT_STORE.is_allowed(&rate_key);
if !allowed {
metrics::counter!("rtix_firewall_blocks_total", "reason" => "rate_limit", "type" => rate_limit_type.to_string()).increment(1);
tracing::warn!(ip = %client_ip, limit_type = %rate_limit_type, "Rate limit exceeded");
// Auto-persistent block if strictly violated (e.g. brute force auth)
if rate_limit_type == "auth" {
metrics::counter!("rtix_firewall_blocks_total", "reason" => "auto_ban").increment(1);
block_ip_persistently(
&state.pool,
&client_ip,
"Automated Ban: Repeated Auth Violations",
Some(&state.tx),
)
.await;
}
return rate_limit_response(rate_limit_type);
}
// 3. Request Logging & Headers
let origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("none");
tracing::info!(ip = %client_ip, method = %req.method(), path = %path, origin = %origin, "Request started");
let mut req = req;
req.headers_mut().insert(
"x-request-id",
HeaderValue::from_str(&request_id).unwrap_or(HeaderValue::from_static("unknown")),
);
// Store in extensions for services to consume
req.extensions_mut().insert(RequestId(request_id.clone()));
let response = next.run(req).await;
let mut response = response;
response.headers_mut().insert(
"x-request-id",
HeaderValue::from_str(&request_id).unwrap_or(HeaderValue::from_static("unknown")),
);
// Security Hardening Headers (pre-allocated statics — zero allocation cost)
response
.headers_mut()
.insert("X-Content-Type-Options", HEADER_NOSNIFF.clone());
response
.headers_mut()
.insert("X-Frame-Options", HEADER_DENY.clone());
response
.headers_mut()
.insert("X-XSS-Protection", HEADER_XSS.clone());
response
.headers_mut()
.insert("Referrer-Policy", HEADER_REFERRER.clone());
response
.headers_mut()
.insert("Strict-Transport-Security", HEADER_HSTS.clone());
response
.headers_mut()
.insert("Content-Security-Policy", HEADER_CSP.clone());
tracing::info!(
ip = %client_ip,
request_id = %request_id,
status = %response.status(),
"Request completed"
);
response
}
/// Manually block an IP address persistently across restarts and sessions.
/// This inserts the IP into the database and the live in-memory cache.
pub async fn block_ip_persistently(
pool: &crate::infrastructure::db::DbPool,
ip: &str,
reason: &str,
tx: Option<&tokio::sync::broadcast::Sender<crate::interfaces::http::api::RealtimeEvent>>,
) {
let _ = sqlx::query("INSERT INTO security_blocks (ip, reason, block_level, expires_at) VALUES ($1, $2, $3, NOW() + INTERVAL '24 hours') ON CONFLICT (ip) DO UPDATE SET expires_at = NOW() + INTERVAL '24 hours', reason = EXCLUDED.reason")
.bind(ip)
.bind(reason)
.bind("BLOCK")
.execute(pool)
.await;
BLOCKED_IP_CACHE.insert(ip.to_string(), Instant::now());
BLOOM_FILTER.insert(ip);
if let Some(sender) = tx {
let _ = sender.send(crate::interfaces::http::api::RealtimeEvent::SentinelBlock {
ip: ip.to_string(),
reason: reason.to_string(),
});
}
tracing::error!(ip = %ip, reason = %reason, "IP persistently blocked via automated defense");
}
pub async fn api_key_auth(
State(state): State<AppState>,
req: axum::extract::Request,
next: Next,
) -> axum::response::Response {
let key_header = req.headers().get("x-api-key").and_then(|v| v.to_str().ok());
let key_str = match key_header {
Some(k) => k,
None => {
let mut res = axum::response::Response::new(Body::from("Missing X-API-Key header"));
*res.status_mut() = StatusCode::UNAUTHORIZED;
return res;
}
};
let parts: Vec<&str> = key_str.split('.').collect();
if parts.len() != 2 {
let mut res = axum::response::Response::new(Body::from(
"Invalid API key format. Expected 'id.secret'",
));
*res.status_mut() = StatusCode::BAD_REQUEST;
return res;
}
let key_id = parts[0];
let secret = parts[1];
let key_record =
match crate::domain::models::ApiKeyRecord::find_by_id(&state.pool, key_id).await {
Ok(Some(k)) => k,
_ => {
let mut res = axum::response::Response::new(Body::from("Invalid API key"));
*res.status_mut() = StatusCode::UNAUTHORIZED;
return res;
}
};
// Verify secret hash
use argon2::{Argon2, PasswordHash, PasswordVerifier};
let parsed_hash = match PasswordHash::new(&key_record.secret_hash) {
Ok(h) => h,
Err(_) => {
let mut res = axum::response::Response::new(Body::from("Internal security error"));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return res;
}
};
if Argon2::default()
.verify_password(secret.as_bytes(), &parsed_hash)
.is_err()
{
let mut res = axum::response::Response::new(Body::from("Invalid API secret"));
*res.status_mut() = StatusCode::UNAUTHORIZED;
return res;
}
// Add merchant_id to request extensions for later extractors
let mut req = req;
req.extensions_mut().insert(key_record.merchant_id);
next.run(req).await
}