Spaces:
Sleeping
Sleeping
| // API Key Authentication Middleware | |
| use axum::{ | |
| extract::Request, | |
| http::{header, StatusCode}, | |
| middleware::Next, | |
| response::{IntoResponse, Response}, | |
| }; | |
| use once_cell::sync::Lazy; | |
| use std::collections::HashSet; | |
| /// Load API keys from environment variables | |
| /// Supports multiple formats: | |
| /// - API_KEYS: comma-separated list (e.g., "key1,key2,key3") | |
| /// - API_KEY_1, API_KEY_2, API_KEY_3, ... : individual keys | |
| /// - API_KEY: single key (backward compatible) | |
| fn load_api_keys() -> HashSet<String> { | |
| let mut keys = HashSet::new(); | |
| // Format 1: API_KEYS (comma-separated) | |
| if let Ok(keys_str) = std::env::var("API_KEYS") { | |
| for key in keys_str.split(',') { | |
| let key = key.trim(); | |
| if !key.is_empty() { | |
| keys.insert(key.to_string()); | |
| tracing::info!("Loaded API key from API_KEYS: {}...", &key[..key.len().min(8)]); | |
| } | |
| } | |
| } | |
| // Format 2: API_KEY_1, API_KEY_2, etc. | |
| for i in 1..=10 { | |
| if let Ok(key) = std::env::var(format!("API_KEY_{}", i)) { | |
| let key = key.trim().to_string(); | |
| if !key.is_empty() { | |
| tracing::info!("Loaded API key from API_KEY_{}: {}...", i, &key[..key.len().min(8)]); | |
| keys.insert(key); | |
| } | |
| } | |
| } | |
| // Format 3: Single API_KEY | |
| if let Ok(key) = std::env::var("API_KEY") { | |
| let key = key.trim().to_string(); | |
| if !key.is_empty() { | |
| tracing::info!("Loaded API key from API_KEY: {}...", &key[..key.len().min(8)]); | |
| keys.insert(key); | |
| } | |
| } | |
| if keys.is_empty() { | |
| tracing::warn!("No API keys configured! Set API_KEYS, API_KEY_1, or API_KEY environment variable."); | |
| tracing::warn!("API authentication is DISABLED - all requests will be allowed."); | |
| } else { | |
| tracing::info!("Loaded {} API key(s) for authentication", keys.len()); | |
| } | |
| keys | |
| } | |
| /// Static API keys set (loaded once at startup) | |
| static API_KEYS: Lazy<HashSet<String>> = Lazy::new(load_api_keys); | |
| /// Check if authentication is enabled | |
| fn is_auth_enabled() -> bool { | |
| !API_KEYS.is_empty() | |
| } | |
| /// Validate API key | |
| fn validate_api_key(key: &str) -> bool { | |
| API_KEYS.contains(key) | |
| } | |
| /// API Key authentication middleware for proxy endpoints | |
| pub async fn auth_middleware(request: Request, next: Next) -> Result<Response, Response> { | |
| let path = request.uri().path(); | |
| // Log the request | |
| tracing::debug!("Request: {} {}", request.method(), request.uri()); | |
| // Skip authentication for health check endpoints | |
| if path == "/healthz" || path == "/api/health" { | |
| return Ok(next.run(request).await); | |
| } | |
| // Skip authentication for static files (frontend) | |
| if !path.starts_with("/v1") && !path.starts_with("/api") { | |
| return Ok(next.run(request).await); | |
| } | |
| // Skip authentication for /api/* management endpoints | |
| // (These should be protected by HuggingFace Spaces password) | |
| if path.starts_with("/api/") { | |
| return Ok(next.run(request).await); | |
| } | |
| // If no API keys configured, allow all requests | |
| if !is_auth_enabled() { | |
| return Ok(next.run(request).await); | |
| } | |
| // Extract API key from headers | |
| // Support: Authorization: Bearer <key> or X-API-Key: <key> | |
| let api_key = request | |
| .headers() | |
| .get(header::AUTHORIZATION) | |
| .and_then(|h| h.to_str().ok()) | |
| .and_then(|s| { | |
| // Support both "Bearer <key>" and just "<key>" | |
| s.strip_prefix("Bearer ").or(Some(s)) | |
| }) | |
| .or_else(|| { | |
| request | |
| .headers() | |
| .get("x-api-key") | |
| .and_then(|h| h.to_str().ok()) | |
| }); | |
| match api_key { | |
| Some(key) if validate_api_key(key) => { | |
| tracing::debug!("API key validated successfully"); | |
| Ok(next.run(request).await) | |
| } | |
| Some(_) => { | |
| tracing::warn!("Invalid API key provided for {}", path); | |
| Err(unauthorized_response("Invalid API key")) | |
| } | |
| None => { | |
| tracing::warn!("No API key provided for {}", path); | |
| Err(unauthorized_response("API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header")) | |
| } | |
| } | |
| } | |
| /// Generate unauthorized response | |
| fn unauthorized_response(message: &str) -> Response { | |
| let body = serde_json::json!({ | |
| "error": { | |
| "message": message, | |
| "type": "authentication_error", | |
| "code": "invalid_api_key" | |
| } | |
| }); | |
| ( | |
| StatusCode::UNAUTHORIZED, | |
| [(header::CONTENT_TYPE, "application/json")], | |
| serde_json::to_string(&body).unwrap(), | |
| ) | |
| .into_response() | |
| } | |
#[cfg(test)]
mod tests {
    use super::*;

    /// The 401 helper must produce the status and JSON content type clients
    /// rely on, regardless of which API keys are configured.
    #[test]
    fn unauthorized_response_is_json_401() {
        let response = unauthorized_response("Invalid API key");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response
                .headers()
                .get(header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
    }

    /// Empty entries are skipped while loading keys, so an empty string can
    /// never authenticate, whatever the environment contains.
    #[test]
    fn empty_key_is_never_valid() {
        assert!(!validate_api_key(""));
    }

    // Loading keys from specific environment variables is intentionally not
    // tested here: `API_KEYS` is a process-wide `Lazy`, and mutating env vars
    // would race with other tests running in parallel.
}