File size: 5,769 Bytes
1295969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
//! Per-IP sliding-window rate limiter as Axum middleware.
//!
//! Limits (per rolling 60-second window):
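//!   /api/auth/*     → 10 req/min  (brute-force / challenge-grind protection)
//!   /api/upload     → 5  req/min  (large-file upload rate limit)
//!   everything else → 120 req/min (2 req/sec on average; the sliding window
//!                     admits all 120 at once, so bursts of up to 120 are possible)
//!   /health and /metrics are exempt from limiting.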
//!
//! IP resolution priority:
//!   1. X-Real-IP header (set by Replit / nginx proxy)
//!   2. first (leftmost) IP in the X-Forwarded-For header
//!   3. "unknown" — all unresolvable clients share one `{bucket}:unknown` key
//!      per bucket, capped at 1/10 of that bucket's limit (minimum 1)
//!
//! State is in-memory and per-process — counters reset on server restart and
//! are not shared between instances (acceptable for short sliding-window
//! limits; limits that must persist or span instances need e.g. Redis).
//!
//! Memory: each tracked key (bucket + IP) costs ~72 bytes + 24 bytes × requests_in_window.
//! At 120 req/min/IP and 10,000 active IPs: ≈ 40 MB maximum.
//! Stale IPs are pruned when the map exceeds 50,000 entries.

use crate::AppState;
use axum::{
    extract::{Request, State},
    http::StatusCode,
    middleware::Next,
    response::Response,
};
use std::{collections::HashMap, sync::Mutex, time::Instant};
use tracing::warn;

/// Length of the sliding window, in seconds.
const WINDOW_SECS: u64 = 60;

// Per-bucket limits: maximum accepted requests per `WINDOW_SECS` window.

/// Limit for every path not covered by a more specific bucket.
const GENERAL_LIMIT: usize = 120;
/// Limit for `/api/auth/*` (brute-force / challenge-grind protection).
const AUTH_LIMIT: usize = 10;
/// Limit for `/api/upload`.
const UPLOAD_LIMIT: usize = 5;

/// Divisor applied to a bucket's limit when the request's source IP cannot be
/// determined. The middleware floors the result at 1, so auth drops from 10 to
/// 1, upload from 5 to 1, and general from 120 to 12.
///
/// All such requests share one key per bucket ("auth:unknown",
/// "general:unknown", etc.), so they cannot be limited per client. The tight
/// cap bounds the total volume of unattributable traffic. The trade-off is
/// that a single noisy unresolvable source can exhaust the shared key for
/// every other unresolvable client. Legitimate deployments should configure a
/// reverse proxy that sets X-Real-IP so this fallback is never hit.
const UNKNOWN_LIMIT_DIVISOR: usize = 10;

/// In-memory sliding-window rate limiter keyed by arbitrary strings.
///
/// Each key remembers the instants of its accepted requests within the last
/// `WINDOW_SECS` seconds; rejected requests are not recorded.
pub struct RateLimiter {
    /// Key: `"{path_bucket}:{client_ip}"` → instants of accepted requests, oldest first.
    windows: Mutex<HashMap<String, Vec<Instant>>>,
    /// Map size above which the next accepted request sweeps expired keys.
    /// Only read or written while `windows` is locked; it is atomic purely so
    /// it can be updated through `&self`.
    sweep_at: std::sync::atomic::AtomicUsize,
}

impl Default for RateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

impl RateLimiter {
    /// Map size that first triggers a sweep of expired keys.
    const SWEEP_THRESHOLD: usize = 50_000;

    /// Creates an empty limiter.
    pub fn new() -> Self {
        Self {
            windows: Mutex::new(HashMap::new()),
            sweep_at: std::sync::atomic::AtomicUsize::new(Self::SWEEP_THRESHOLD),
        }
    }

    /// Records a request against `key` and reports whether it is allowed.
    ///
    /// Returns `true` and records the request if fewer than `limit` requests
    /// for `key` were accepted in the last `WINDOW_SECS` seconds. Otherwise it
    /// returns `false` without recording anything, so a `limit` of 0 rejects
    /// every request.
    ///
    /// A poisoned lock (an earlier holder panicked) is recovered rather than
    /// treated as "allow", so a single panic cannot silently disable limiting.
    pub fn check(&self, key: &str, limit: usize) -> bool {
        use std::sync::atomic::Ordering;

        let now = Instant::now();
        let window = std::time::Duration::from_secs(WINDOW_SECS);
        let is_live = |t: &Instant| now.duration_since(*t) < window;

        // The map only holds timestamps, which stay consistent even if a
        // previous holder panicked, so recovering the guard is sound.
        let mut map = self
            .windows
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let times = map.entry(key.to_owned()).or_default();
        // Drop instants that have slid out of the window for this key.
        times.retain(is_live);
        if times.len() >= limit {
            return false;
        }
        times.push(now);

        // Bound memory by dropping keys whose timestamps have all expired.
        // Expired instants are otherwise only pruned when their own key is
        // checked, so testing for emptiness alone never reclaims idle clients.
        if map.len() > self.sweep_at.load(Ordering::Relaxed) {
            map.retain(|_, times| {
                times.retain(is_live);
                !times.is_empty()
            });
            // Scale the trigger with the live set so that a map full of active
            // keys is not re-swept on every call (amortised O(1) per request).
            let next = map.len().saturating_mul(2).max(Self::SWEEP_THRESHOLD);
            self.sweep_at.store(next, Ordering::Relaxed);
        }
        true
    }
}

/// Returns `true` only when `s` is, verbatim, an IPv4 or IPv6 address literal.
///
/// Anything else is rejected: empty input, hostnames, surrounding whitespace,
/// `ip:port` forms, or injected junk. This keeps only genuine addresses in the
/// rate-limit key.
fn is_valid_ip(s: &str) -> bool {
    let parsed: Result<std::net::IpAddr, _> = s.parse();
    parsed.is_ok()
}

/// Determines the client address used in the rate-limit key.
///
/// Lookup order:
/// 1. `X-Real-IP` (Nginx / Replit proxy), trimmed.
/// 2. The leftmost entry of `X-Forwarded-For`, trimmed.
/// 3. The literal `"unknown"`.
///
/// A header value is accepted only if it parses as an IP address, so crafted
/// headers cannot inject arbitrary text into the key. A header that is present
/// and readable but invalid is logged and skipped.
///
/// NOTE(review): when the front proxy appends to X-Forwarded-For rather than
/// overwriting it, the leftmost entry is client-supplied and can be spoofed to
/// dodge per-IP limits. Confirm the deployment's proxy sets X-Real-IP.
fn client_ip(request: &Request) -> String {
    let headers = request.headers();

    // Single address supplied by the reverse proxy.
    if let Some(raw) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) {
        let candidate = raw.trim();
        if is_valid_ip(candidate) {
            return candidate.to_owned();
        }
        warn!(raw=%candidate, "x-real-ip header is not a valid IP β€” ignoring");
    }

    // "client, proxy1, proxy2" — only the first (leftmost) hop is considered.
    if let Some(first) = headers
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|list| list.split(',').next())
    {
        let candidate = first.trim();
        if is_valid_ip(candidate) {
            return candidate.to_owned();
        }
        warn!(raw=%candidate, "x-forwarded-for first entry is not a valid IP β€” ignoring");
    }

    "unknown".to_owned()
}

/// Maps a request path to its rate-limit bucket name and per-window limit.
///
/// `/api/upload` (exact match) is the upload bucket. Any path beginning with
/// `/api/auth/` is the auth bucket. Everything else, including `/api/auth`
/// without a trailing slash, falls into the general bucket.
fn bucket(path: &str) -> (&'static str, usize) {
    match path {
        "/api/upload" => ("upload", UPLOAD_LIMIT),
        p if p.starts_with("/api/auth/") => ("auth", AUTH_LIMIT),
        _ => ("general", GENERAL_LIMIT),
    }
}

/// Axum middleware that applies the per-IP, per-bucket rate limits.
///
/// `/health` and `/metrics` are passed straight through. Every other request
/// is counted under the key `"{bucket}:{client_ip}"`. Once that key's limit
/// is reached, a warning is logged and `Err(StatusCode::TOO_MANY_REQUESTS)`
/// (429) is returned without running the inner service.
///
/// NOTE(review): the 429 carries no `Retry-After` header; consider adding one.
pub async fn enforce(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let path = request.uri().path().to_string();

    // Liveness and scrape endpoints are never throttled.
    if matches!(path.as_str(), "/health" | "/metrics") {
        return Ok(next.run(request).await);
    }

    let ip = client_ip(&request);
    let (bucket_name, base_limit) = bucket(&path);

    // Unresolvable clients share a single key per bucket and cannot be told
    // apart, so their combined traffic is held to a fraction of the normal
    // limit (never below 1).
    let limit = match ip.as_str() {
        "unknown" => (base_limit / UNKNOWN_LIMIT_DIVISOR).max(1),
        _ => base_limit,
    };

    if state
        .rate_limiter
        .check(&format!("{bucket_name}:{ip}"), limit)
    {
        return Ok(next.run(request).await);
    }

    warn!(
        ip=%ip,
        path=%path,
        bucket=%bucket_name,
        limit=%limit,
        "Rate limit exceeded β€” 429"
    );
    Err(StatusCode::TOO_MANY_REQUESTS)
}