Spaces:
Build error
Build error
File size: 5,769 Bytes
1295969 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 | //! Per-IP sliding-window rate limiter as Axum middleware.
//!
//! Limits (per rolling 60-second window):
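//! /api/auth/* → 10 req/min (brute-force / challenge-grind protection)
//! /api/upload → 5 req/min (large file upload rate-limit)
//! everything else → 120 req/min (≈2 req/sec on average; the whole quota may arrive as one burst)
//! /health and /metrics are exempt from rate limiting.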
//!
//! IP resolution priority:
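//! 1. X-Real-IP header (set by Replit / nginx proxy)
//! 2. first IP in X-Forwarded-For header
//! 3. "unknown" — all unresolvable clients share one key per bucket
//!    ("general:unknown", "auth:unknown", …) whose limit is divided by
//!    UNKNOWN_LIMIT_DIVISOR (minimum 1)
//!
//! Header values are used only if they parse as an IP address. They are only
//! trustworthy when a reverse proxy overwrites them; a client that reaches the
//! server directly can put any address in them.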
//!
//! State is in-memory — counters reset on server restart and are not shared
//! between server processes (acceptable for stateless sliding-window limits;
//! persistent or cluster-wide limits need Redis).
//!
//! Memory: each tracked (bucket, IP) key costs ~72 bytes + 24 bytes × requests_in_window.
//! At 120 req/min/IP and 10,000 active IPs: ≈ 40 MB (a rough estimate — there is
//! no hard cap on the number of tracked keys).
//! Stale keys are pruned when the map exceeds 50,000 entries.
use crate::AppState;
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use std::{collections::HashMap, sync::Mutex, time::Instant};
use tracing::warn;
/// Length of the sliding window, in seconds, over which requests are counted.
const WINDOW_SECS: u64 = 60;
/// Requests per window allowed on any path not covered by a stricter bucket.
const GENERAL_LIMIT: usize = 120;
/// Requests per window allowed on `/api/auth/*` (brute-force protection).
const AUTH_LIMIT: usize = 10;
/// Requests per window allowed on `/api/upload`.
const UPLOAD_LIMIT: usize = 5;
/// Divisor applied to a bucket's limit when the client IP cannot be determined.
///
/// All such requests share one key per bucket ("auth:unknown",
/// "general:unknown", etc.), and their effective limit is
/// `(bucket_limit / UNKNOWN_LIMIT_DIVISOR).max(1)`. With the values above, the
/// upload bucket's 5 / 10 = 0 is clamped to 1. The much tighter cap keeps one
/// attacker (or broken proxy) from exhausting the shared key and causing
/// collateral DoS for other unresolvable clients. Legitimate deployments should
/// configure a reverse proxy that sets X-Real-IP so this fallback is never hit.
const UNKNOWN_LIMIT_DIVISOR: usize = 10;
/// In-memory, per-key sliding-window rate limiter.
pub struct RateLimiter {
    /// Key: `"{path_bucket}:{client_ip}"` → timestamps of accepted requests,
    /// oldest first.
    windows: Mutex<HashMap<String, Vec<Instant>>>,
}

impl Default for RateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

impl RateLimiter {
    /// Creates a limiter that is not yet tracking any keys.
    pub fn new() -> Self {
        Self {
            windows: Mutex::new(HashMap::new()),
        }
    }

    /// Records a request for `key` and reports whether it fits within `limit`
    /// requests per `WINDOW_SECS`-second sliding window.
    ///
    /// Returns `true` if the request is allowed; its timestamp is then recorded.
    /// Returns `false` if it must be rejected; nothing is recorded, so rejected
    /// requests do not extend the penalty. A `limit` of 0 rejects everything.
    #[must_use]
    pub fn check(&self, key: &str, limit: usize) -> bool {
        /// Once the map tracks more keys than this, idle keys are swept out.
        const MAX_TRACKED_KEYS: usize = 50_000;

        let window = std::time::Duration::from_secs(WINDOW_SECS);

        // A poisoned lock only means another thread panicked while holding it.
        // The timestamp lists are still structurally valid, so keep enforcing
        // limits instead of silently failing open for the rest of the process.
        let mut map = self
            .windows
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Read the clock while holding the lock so each list stays in
        // chronological order even with concurrent callers.
        let now = Instant::now();

        let times = map.entry(key.to_string()).or_default();
        // Forget requests that have slid out of the window.
        times.retain(|&t| now.duration_since(t) < window);
        if times.len() >= limit {
            return false;
        }
        times.push(now);

        // Bound memory by dropping keys with no request inside the current window.
        // A list is only pruned when its own key is checked, so testing
        // `is_empty()` would never remove anything. Because lists are sorted,
        // checking the newest timestamp is enough.
        // NOTE(review): if more than MAX_TRACKED_KEYS keys are genuinely active,
        // this O(n) sweep runs on every call.
        if map.len() > MAX_TRACKED_KEYS {
            map.retain(|_, stamps| {
                matches!(stamps.last(), Some(&newest) if now.duration_since(newest) < window)
            });
        }
        true
    }
}
/// Returns `true` when `s` parses as an IPv4 or IPv6 address.
///
/// Everything else — the empty string, hostnames, `addr:port` pairs, padded
/// text, or arbitrary header payloads — is rejected, so none of it can become
/// part of a rate-limit key.
fn is_valid_ip(s: &str) -> bool {
    let parsed: Result<std::net::IpAddr, _> = s.parse();
    parsed.is_ok()
}
/// Determines the client address used to key rate limits.
///
/// Sources, in order:
/// 1. `X-Real-IP` (set by an Nginx / Replit proxy), trimmed.
/// 2. The leftmost entry of `X-Forwarded-For` ("client, proxy1, proxy2"), trimmed.
/// 3. The literal `"unknown"`.
///
/// A header value is accepted only if it parses as an IP address. That way an
/// attacker cannot inject arbitrary text into the rate-limit key. An invalid
/// value is logged and the next source is tried. A value that is not valid
/// header text is skipped silently.
///
/// NOTE(review): both headers are client-controllable unless the reverse proxy
/// overwrites them, so per-IP limits are only as strong as the proxy config.
fn client_ip(request: &Request) -> String {
    let headers = request.headers();

    let real_ip = headers
        .get("x-real-ip")
        .and_then(|value| value.to_str().ok())
        .map(str::trim);
    if let Some(candidate) = real_ip {
        if is_valid_ip(candidate) {
            return candidate.to_owned();
        }
        warn!(raw=%candidate, "x-real-ip header is not a valid IP β ignoring");
    }

    let forwarded_first = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .map(str::trim);
    if let Some(candidate) = forwarded_first {
        if is_valid_ip(candidate) {
            return candidate.to_owned();
        }
        warn!(raw=%candidate, "x-forwarded-for first entry is not a valid IP β ignoring");
    }

    // Nothing usable: every such client shares the "unknown" key.
    String::from("unknown")
}
/// Maps a request path to its rate-limit bucket name and per-window limit.
///
/// `/api/auth/...` (with the trailing slash) is the auth bucket. Exactly
/// `/api/upload` is the upload bucket. Every other path is general.
fn bucket(path: &str) -> (&'static str, usize) {
    match path {
        "/api/upload" => ("upload", UPLOAD_LIMIT),
        auth if auth.starts_with("/api/auth/") => ("auth", AUTH_LIMIT),
        _ => ("general", GENERAL_LIMIT),
    }
}
/// Axum middleware that applies the per-IP, per-bucket rate limits.
///
/// `/health` and `/metrics` bypass the limiter. Any other request passes to
/// `next` if `state.rate_limiter` admits it. Otherwise the middleware logs a
/// warning and responds with `429 Too Many Requests`.
pub async fn enforce(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Owned copy: `request` is moved into `next.run` below, but the path is
    // still needed for bucketing and logging.
    let path = request.uri().path().to_string();

    let exempt = matches!(path.as_str(), "/health" | "/metrics");
    if exempt {
        return Ok(next.run(request).await);
    }

    let ip = client_ip(&request);
    let (bucket_name, base_limit) = bucket(&path);

    let limit = match ip.as_str() {
        // Unresolvable clients share one key per bucket, so they get a reduced
        // (but never zero) quota to limit collateral damage between them.
        "unknown" => std::cmp::max(base_limit / UNKNOWN_LIMIT_DIVISOR, 1),
        _ => base_limit,
    };

    let key = format!("{bucket_name}:{ip}");
    let allowed = state.rate_limiter.check(&key, limit);
    if allowed {
        return Ok(next.run(request).await);
    }

    warn!(
        ip=%ip,
        path=%path,
        bucket=%bucket_name,
        limit=%limit,
        "Rate limit exceeded β 429"
    );
    Err(StatusCode::TOO_MANY_REQUESTS)
}
|