File size: 4,936 Bytes
// API Key Authentication Middleware
use axum::{
extract::Request,
http::{header, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use once_cell::sync::Lazy;
use std::collections::HashSet;
/// Highest `n` probed for the numbered `API_KEY_<n>` environment variables.
const MAX_NUMBERED_API_KEYS: usize = 10;

/// Maximum number of leading characters of a key that may appear in log output.
const KEY_LOG_PREFIX_CHARS: usize = 8;

/// Return at most the first `KEY_LOG_PREFIX_CHARS` characters of `key` for logging.
///
/// Counts `char`s, not bytes. Byte slicing (`&key[..8]`) panics when byte 8 falls
/// inside a multi-byte UTF-8 character.
fn key_preview(key: &str) -> &str {
    match key.char_indices().nth(KEY_LOG_PREFIX_CHARS) {
        Some((end, _)) => &key[..end],
        None => key,
    }
}

/// Trim `raw` and, if anything is left, log its origin and add it to `keys`.
///
/// `source` names the environment variable the key came from and is used only
/// in the log line.
fn add_key(keys: &mut HashSet<String>, source: &str, raw: &str) {
    let key = raw.trim();
    if key.is_empty() {
        return;
    }
    tracing::info!("Loaded API key from {}: {}...", source, key_preview(key));
    keys.insert(key.to_string());
}

/// Load API keys from environment variables.
///
/// Supports multiple formats, and the keys from all of them are merged into one set:
/// - `API_KEYS`: comma-separated list (e.g., "key1,key2,key3")
/// - `API_KEY_1` ..= `API_KEY_10`: individual keys
/// - `API_KEY`: single key (backward compatible)
///
/// Every key is trimmed, and empty entries are ignored. An empty result means
/// authentication is disabled, and a warning is logged.
fn load_api_keys() -> HashSet<String> {
    let mut keys = HashSet::new();

    // Format 1: API_KEYS (comma-separated)
    if let Ok(list) = std::env::var("API_KEYS") {
        for key in list.split(',') {
            add_key(&mut keys, "API_KEYS", key);
        }
    }

    // Format 2: API_KEY_1, API_KEY_2, ... up to MAX_NUMBERED_API_KEYS
    for i in 1..=MAX_NUMBERED_API_KEYS {
        let var = format!("API_KEY_{}", i);
        if let Ok(key) = std::env::var(&var) {
            add_key(&mut keys, &var, &key);
        }
    }

    // Format 3: single API_KEY
    if let Ok(key) = std::env::var("API_KEY") {
        add_key(&mut keys, "API_KEY", &key);
    }

    if keys.is_empty() {
        tracing::warn!("No API keys configured! Set API_KEYS, API_KEY_1, or API_KEY environment variable.");
        tracing::warn!("API authentication is DISABLED - all requests will be allowed.");
    } else {
        tracing::info!("Loaded {} API key(s) for authentication", keys.len());
    }
    keys
}
/// Set of accepted API keys, built by `load_api_keys` on first dereference.
///
/// `Lazy` defers initialization until the first access (in this file, the first call
/// to `is_auth_enabled` or `validate_api_key`), not process startup. Environment
/// variables are therefore read at that moment, and never again afterwards.
static API_KEYS: Lazy<HashSet<String>> = Lazy::new(load_api_keys);
/// Check if authentication is enabled
fn is_auth_enabled() -> bool {
!API_KEYS.is_empty()
}
/// Validate API key
fn validate_api_key(key: &str) -> bool {
API_KEYS.contains(key)
}
/// API Key authentication middleware for proxy endpoints
pub async fn auth_middleware(request: Request, next: Next) -> Result<Response, Response> {
let path = request.uri().path();
// Log the request
tracing::debug!("Request: {} {}", request.method(), request.uri());
// Skip authentication for health check endpoints
if path == "/healthz" || path == "/api/health" {
return Ok(next.run(request).await);
}
// Skip authentication for static files (frontend)
if !path.starts_with("/v1") && !path.starts_with("/api") {
return Ok(next.run(request).await);
}
// Skip authentication for /api/* management endpoints
// (These should be protected by HuggingFace Spaces password)
if path.starts_with("/api/") {
return Ok(next.run(request).await);
}
// If no API keys configured, allow all requests
if !is_auth_enabled() {
return Ok(next.run(request).await);
}
// Extract API key from headers
// Support: Authorization: Bearer <key> or X-API-Key: <key>
let api_key = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|s| {
// Support both "Bearer <key>" and just "<key>"
s.strip_prefix("Bearer ").or(Some(s))
})
.or_else(|| {
request
.headers()
.get("x-api-key")
.and_then(|h| h.to_str().ok())
});
match api_key {
Some(key) if validate_api_key(key) => {
tracing::debug!("API key validated successfully");
Ok(next.run(request).await)
}
Some(_) => {
tracing::warn!("Invalid API key provided for {}", path);
Err(unauthorized_response("Invalid API key"))
}
None => {
tracing::warn!("No API key provided for {}", path);
Err(unauthorized_response("API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header"))
}
}
}
/// Build the `401 Unauthorized` response used for every authentication failure.
///
/// The body is a compact JSON object of the form
/// `{"error": {"message": <message>, "type": "authentication_error", "code": "invalid_api_key"}}`
/// and the `Content-Type` header is set to `application/json`.
fn unauthorized_response(message: &str) -> Response {
    // `Value`'s `Display` impl emits the same compact JSON as `serde_json::to_string`,
    // without a `Result` to unwrap.
    let payload = serde_json::json!({
        "error": {
            "message": message,
            "type": "authentication_error",
            "code": "invalid_api_key"
        }
    })
    .to_string();

    let headers = [(header::CONTENT_TYPE, "application/json")];
    (StatusCode::UNAUTHORIZED, headers, payload).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Key loading is not tested here. `API_KEYS` is a process-wide `Lazy`, and mutating
    // environment variables races with tests running in parallel.

    /// The 401 helper must set the status code and a JSON content type.
    #[test]
    fn unauthorized_response_is_401_json() {
        let response = unauthorized_response("Invalid API key");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response.headers()[header::CONTENT_TYPE],
            "application/json"
        );
    }
}
|