| |
| use axum::{ |
| extract::Request, |
| http::{header, StatusCode}, |
| middleware::Next, |
| response::{IntoResponse, Response}, |
| }; |
| use once_cell::sync::Lazy; |
| use std::collections::HashSet; |
|
|
| |
| |
| |
| |
| |
| fn load_api_keys() -> HashSet<String> { |
| let mut keys = HashSet::new(); |
|
|
| |
| if let Ok(keys_str) = std::env::var("API_KEYS") { |
| for key in keys_str.split(',') { |
| let key = key.trim(); |
| if !key.is_empty() { |
| keys.insert(key.to_string()); |
| tracing::info!("Loaded API key from API_KEYS: {}...", &key[..key.len().min(8)]); |
| } |
| } |
| } |
|
|
| |
| for i in 1..=10 { |
| if let Ok(key) = std::env::var(format!("API_KEY_{}", i)) { |
| let key = key.trim().to_string(); |
| if !key.is_empty() { |
| tracing::info!("Loaded API key from API_KEY_{}: {}...", i, &key[..key.len().min(8)]); |
| keys.insert(key); |
| } |
| } |
| } |
|
|
| |
| if let Ok(key) = std::env::var("API_KEY") { |
| let key = key.trim().to_string(); |
| if !key.is_empty() { |
| tracing::info!("Loaded API key from API_KEY: {}...", &key[..key.len().min(8)]); |
| keys.insert(key); |
| } |
| } |
|
|
| if keys.is_empty() { |
| tracing::warn!("No API keys configured! Set API_KEYS, API_KEY_1, or API_KEY environment variable."); |
| tracing::warn!("API authentication is DISABLED - all requests will be allowed."); |
| } else { |
| tracing::info!("Loaded {} API key(s) for authentication", keys.len()); |
| } |
|
|
| keys |
| } |
|
|
| |
/// Process-wide set of accepted API keys.
///
/// Built by `load_api_keys` the first time the static is dereferenced; the
/// set is not rebuilt afterwards.
| static API_KEYS: Lazy<HashSet<String>> = Lazy::new(load_api_keys); |
|
|
| |
| fn is_auth_enabled() -> bool { |
| !API_KEYS.is_empty() |
| } |
|
|
| |
| fn validate_api_key(key: &str) -> bool { |
| API_KEYS.contains(key) |
| } |
|
|
| |
| pub async fn auth_middleware(request: Request, next: Next) -> Result<Response, Response> { |
| let path = request.uri().path(); |
|
|
| |
| tracing::debug!("Request: {} {}", request.method(), request.uri()); |
|
|
| |
| if path == "/healthz" || path == "/api/health" { |
| return Ok(next.run(request).await); |
| } |
|
|
| |
| if !path.starts_with("/v1") && !path.starts_with("/api") { |
| return Ok(next.run(request).await); |
| } |
|
|
| |
| |
| if path.starts_with("/api/") { |
| return Ok(next.run(request).await); |
| } |
|
|
| |
| if !is_auth_enabled() { |
| return Ok(next.run(request).await); |
| } |
|
|
| |
| |
| let api_key = request |
| .headers() |
| .get(header::AUTHORIZATION) |
| .and_then(|h| h.to_str().ok()) |
| .and_then(|s| { |
| |
| s.strip_prefix("Bearer ").or(Some(s)) |
| }) |
| .or_else(|| { |
| request |
| .headers() |
| .get("x-api-key") |
| .and_then(|h| h.to_str().ok()) |
| }); |
|
|
| match api_key { |
| Some(key) if validate_api_key(key) => { |
| tracing::debug!("API key validated successfully"); |
| Ok(next.run(request).await) |
| } |
| Some(_) => { |
| tracing::warn!("Invalid API key provided for {}", path); |
| Err(unauthorized_response("Invalid API key")) |
| } |
| None => { |
| tracing::warn!("No API key provided for {}", path); |
| Err(unauthorized_response("API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header")) |
| } |
| } |
| } |
|
|
| |
| fn unauthorized_response(message: &str) -> Response { |
| let body = serde_json::json!({ |
| "error": { |
| "message": message, |
| "type": "authentication_error", |
| "code": "invalid_api_key" |
| } |
| }); |
|
|
| ( |
| StatusCode::UNAUTHORIZED, |
| [(header::CONTENT_TYPE, "application/json")], |
| serde_json::to_string(&body).unwrap(), |
| ) |
| .into_response() |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Whatever the environment contains, every loaded key must be stored
    /// trimmed and non-empty.
    #[test]
    fn test_api_key_loading() {
        let keys = load_api_keys();
        assert!(keys
            .iter()
            .all(|key| !key.is_empty() && key.trim() == key.as_str()));
    }

    /// The rejection response must be a 401 declared as JSON.
    #[test]
    fn test_unauthorized_response_is_json_401() {
        let response = unauthorized_response("Invalid API key");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response
                .headers()
                .get(header::CONTENT_TYPE)
                .and_then(|value| value.to_str().ok()),
            Some("application/json")
        );
    }
}
|
|