retro-sync-server / apps /api-server /src /rate_limit.rs
mike dupont
init: retro-sync API server + viewer + 71 Bach tiles + catalog
1295969
//! Per-IP sliding-window rate limiter as Axum middleware.
//!
//! Limits (per rolling 60-second window):
//! /api/auth/* β†’ 10 req/min (brute-force / challenge-grind protection)
//! /api/upload β†’ 5 req/min (large file upload rate-limit)
//! everything else β†’ 120 req/min (2 req/sec burst)
//!
//! IP resolution priority:
//! 1. X-Real-IP header (set by Replit / nginx proxy)
//! 2. first IP in X-Forwarded-For header
//! 3. "unknown" (all unknown clients share the general bucket)
//!
//! State is in-memory β€” counters reset on server restart (acceptable for
//! stateless sliding-window limits; persistent limits need Redis).
//!
//! Memory: each tracked IP costs ~72 bytes + 24 bytes Γ— requests_in_window.
//! At 120 req/min/IP and 10,000 active IPs: β‰ˆ 40 MB maximum.
//! Stale IPs are pruned when the map exceeds 50,000 entries.
use crate::AppState;
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use std::{collections::HashMap, sync::Mutex, time::Instant};
use tracing::warn;
const WINDOW_SECS: u64 = 60;
/// Three-bucket limits (req per 60s)
const GENERAL_LIMIT: usize = 120;
const AUTH_LIMIT: usize = 10;
const UPLOAD_LIMIT: usize = 5;
/// Limit applied to requests whose source IP cannot be determined.
///
/// All such requests share the key "auth:unknown", "general:unknown", etc.
/// A much tighter limit than GENERAL_LIMIT prevents an attacker (or broken
/// proxy) from exhausting the shared bucket and causing collateral DoS for
/// other unresolvable clients. Legitimate deployments should configure a
/// reverse proxy that sets X-Real-IP so this fallback is never hit.
const UNKNOWN_LIMIT_DIVISOR: usize = 10;
pub struct RateLimiter {
/// Key: `"{path_bucket}:{client_ip}"` β†’ sorted list of request instants
windows: Mutex<HashMap<String, Vec<Instant>>>,
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
impl RateLimiter {
pub fn new() -> Self {
Self {
windows: Mutex::new(HashMap::new()),
}
}
/// Returns `true` if the request is within the limit, `false` to reject.
pub fn check(&self, key: &str, limit: usize) -> bool {
let now = Instant::now();
let window = std::time::Duration::from_secs(WINDOW_SECS);
if let Ok(mut map) = self.windows.lock() {
let times = map.entry(key.to_string()).or_default();
// Prune entries older than the window
times.retain(|&t| now.duration_since(t) < window);
if times.len() >= limit {
return false;
}
times.push(now);
// Prune stale IPs to bound memory
if map.len() > 50_000 {
map.retain(|_, v| !v.is_empty());
}
}
true
}
}
/// Validate that a string is a well-formed IPv4 or IPv6 address.
/// Rejects empty strings, hostnames, and any header-injection payloads.
fn is_valid_ip(s: &str) -> bool {
s.parse::<std::net::IpAddr>().is_ok()
}
/// Extract client IP from proxy headers, falling back to "unknown".
///
/// Header values are only trusted if they parse as a valid IP address.
/// This prevents an attacker from injecting arbitrary strings into the
/// rate-limit key by setting a crafted X-Forwarded-For or X-Real-IP header.
fn client_ip(request: &Request) -> String {
// X-Real-IP (Nginx / Replit proxy)
if let Some(v) = request.headers().get("x-real-ip") {
if let Ok(s) = v.to_str() {
let ip = s.trim();
if is_valid_ip(ip) {
return ip.to_string();
}
warn!(raw=%ip, "x-real-ip header is not a valid IP β€” ignoring");
}
}
// X-Forwarded-For: client, proxy1, proxy2 β€” take the first (leftmost)
if let Some(v) = request.headers().get("x-forwarded-for") {
if let Ok(s) = v.to_str() {
if let Some(ip) = s.split(',').next() {
let ip = ip.trim();
if is_valid_ip(ip) {
return ip.to_string();
}
warn!(raw=%ip, "x-forwarded-for first entry is not a valid IP β€” ignoring");
}
}
}
"unknown".to_string()
}
/// Classify a request path into a rate-limit bucket.
fn bucket(path: &str) -> (&'static str, usize) {
if path.starts_with("/api/auth/") {
("auth", AUTH_LIMIT)
} else if path == "/api/upload" {
("upload", UPLOAD_LIMIT)
} else {
("general", GENERAL_LIMIT)
}
}
/// Axum middleware: enforce per-IP rate limits.
pub async fn enforce(
State(state): State<AppState>,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
// Exempt health / metrics endpoints from rate limiting
let path = request.uri().path().to_string();
if path == "/health" || path == "/metrics" {
return Ok(next.run(request).await);
}
let ip = client_ip(&request);
let (bucket_name, base_limit) = bucket(&path);
// Apply a tighter cap for requests with no resolvable IP (shared bucket).
// This prevents a single unknown/misconfigured source from starving the
// shared "unknown" key and causing collateral DoS for other clients.
let limit = if ip == "unknown" {
(base_limit / UNKNOWN_LIMIT_DIVISOR).max(1)
} else {
base_limit
};
let key = format!("{bucket_name}:{ip}");
if !state.rate_limiter.check(&key, limit) {
warn!(
ip=%ip,
path=%path,
bucket=%bucket_name,
limit=%limit,
"Rate limit exceeded β€” 429"
);
return Err(StatusCode::TOO_MANY_REQUESTS);
}
Ok(next.run(request).await)
}