File size: 4,936 Bytes
bbb1195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// API Key Authentication Middleware
use axum::{
    extract::Request,
    http::{header, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use once_cell::sync::Lazy;
use std::collections::HashSet;

/// Number of leading characters of a key echoed to the log when it is loaded.
const KEY_PREVIEW_CHARS: usize = 8;

/// Returns at most the first [`KEY_PREVIEW_CHARS`] characters of `key`, for log output.
///
/// The cut is made on a `char` boundary: byte-slicing (`&key[..8]`) would panic when a
/// multi-byte UTF-8 character straddles byte 8.
fn key_preview(key: &str) -> &str {
    key.char_indices()
        .nth(KEY_PREVIEW_CHARS)
        .map_or(key, |(end, _)| &key[..end])
}

/// Reads environment variable `name` and returns its trimmed value, or `None` when the
/// variable is unset, not valid Unicode, or blank.
fn read_trimmed_var(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}

/// Load API keys from environment variables.
///
/// All of the following sources are read and merged into one set (duplicates collapse):
/// - `API_KEYS`: comma-separated list (e.g. `"key1,key2,key3"`)
/// - `API_KEY_1`, `API_KEY_2`, ... `API_KEY_10`: individual keys
/// - `API_KEY`: single key (backward compatible)
///
/// Every key is trimmed of surrounding whitespace and blank entries are skipped. Each
/// loaded key is logged at info level with only its first few characters shown. An empty
/// result is logged as a warning because it disables authentication.
fn load_api_keys() -> HashSet<String> {
    let mut keys = HashSet::new();

    // Format 1: API_KEYS (comma-separated)
    if let Ok(keys_str) = std::env::var("API_KEYS") {
        for key in keys_str.split(',').map(str::trim).filter(|k| !k.is_empty()) {
            tracing::info!("Loaded API key from API_KEYS: {}...", key_preview(key));
            keys.insert(key.to_string());
        }
    }

    // Format 2: API_KEY_1 .. API_KEY_10 (each index is checked; gaps are allowed)
    for i in 1..=10 {
        if let Some(key) = read_trimmed_var(&format!("API_KEY_{}", i)) {
            tracing::info!("Loaded API key from API_KEY_{}: {}...", i, key_preview(&key));
            keys.insert(key);
        }
    }

    // Format 3: Single API_KEY
    if let Some(key) = read_trimmed_var("API_KEY") {
        tracing::info!("Loaded API key from API_KEY: {}...", key_preview(&key));
        keys.insert(key);
    }

    if keys.is_empty() {
        tracing::warn!("No API keys configured! Set API_KEYS, API_KEY_1, or API_KEY environment variable.");
        tracing::warn!("API authentication is DISABLED - all requests will be allowed.");
    } else {
        tracing::info!("Loaded {} API key(s) for authentication", keys.len());
    }

    keys
}

/// Set of accepted API keys, built by `load_api_keys` the first time it is accessed and
/// then reused for the rest of the process. An empty set means authentication is off.
static API_KEYS: Lazy<HashSet<String>> = Lazy::new(load_api_keys);

/// Reports whether API-key authentication is active, i.e. at least one key was loaded.
fn is_auth_enabled() -> bool {
    let configured_keys = API_KEYS.len();
    configured_keys > 0
}

/// Validate API key
fn validate_api_key(key: &str) -> bool {
    API_KEYS.contains(key)
}

/// API Key authentication middleware for proxy endpoints
pub async fn auth_middleware(request: Request, next: Next) -> Result<Response, Response> {
    let path = request.uri().path();

    // Log the request
    tracing::debug!("Request: {} {}", request.method(), request.uri());

    // Skip authentication for health check endpoints
    if path == "/healthz" || path == "/api/health" {
        return Ok(next.run(request).await);
    }

    // Skip authentication for static files (frontend)
    if !path.starts_with("/v1") && !path.starts_with("/api") {
        return Ok(next.run(request).await);
    }

    // Skip authentication for /api/* management endpoints
    // (These should be protected by HuggingFace Spaces password)
    if path.starts_with("/api/") {
        return Ok(next.run(request).await);
    }

    // If no API keys configured, allow all requests
    if !is_auth_enabled() {
        return Ok(next.run(request).await);
    }

    // Extract API key from headers
    // Support: Authorization: Bearer <key> or X-API-Key: <key>
    let api_key = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok())
        .and_then(|s| {
            // Support both "Bearer <key>" and just "<key>"
            s.strip_prefix("Bearer ").or(Some(s))
        })
        .or_else(|| {
            request
                .headers()
                .get("x-api-key")
                .and_then(|h| h.to_str().ok())
        });

    match api_key {
        Some(key) if validate_api_key(key) => {
            tracing::debug!("API key validated successfully");
            Ok(next.run(request).await)
        }
        Some(_) => {
            tracing::warn!("Invalid API key provided for {}", path);
            Err(unauthorized_response("Invalid API key"))
        }
        None => {
            tracing::warn!("No API key provided for {}", path);
            Err(unauthorized_response("API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header"))
        }
    }
}

/// Builds the `401 Unauthorized` response sent when API-key checks fail.
///
/// The body is a JSON document of the form
/// `{"error": {"code": "invalid_api_key", "message": <message>, "type": "authentication_error"}}`
/// served with `Content-Type: application/json`.
fn unauthorized_response(message: &str) -> Response {
    // `Value`'s `Display` impl emits the same compact JSON as `serde_json::to_string`,
    // without a `Result` to unwrap.
    let payload = serde_json::json!({
        "error": {
            "message": message,
            "type": "authentication_error",
            "code": "invalid_api_key"
        }
    })
    .to_string();
    let json_content_type = [(header::CONTENT_TYPE, "application/json")];

    (StatusCode::UNAUTHORIZED, json_content_type, payload).into_response()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Whatever the environment contains, only trimmed, non-empty keys are ever stored.
    #[test]
    fn test_api_key_loading() {
        assert!(API_KEYS
            .iter()
            .all(|k| !k.is_empty() && k.trim() == k.as_str()));
    }

    /// Blank or whitespace-padded keys can never match, because stored keys are trimmed
    /// and non-empty.
    #[test]
    fn empty_or_padded_keys_never_validate() {
        assert!(!validate_api_key(""));
        assert!(!validate_api_key(" padded "));
    }

    #[test]
    fn unauthorized_response_is_401_json() {
        let response = unauthorized_response("Invalid API key");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response
                .headers()
                .get(header::CONTENT_TYPE)
                .expect("unauthorized_response always sets Content-Type"),
            "application/json"
        );
    }
}