// API Key Authentication Middleware use axum::{ extract::Request, http::{header, StatusCode}, middleware::Next, response::{IntoResponse, Response}, }; use once_cell::sync::Lazy; use std::collections::HashSet; /// Load API keys from environment variables /// Supports multiple formats: /// - API_KEYS: comma-separated list (e.g., "key1,key2,key3") /// - API_KEY_1, API_KEY_2, API_KEY_3, ... : individual keys /// - API_KEY: single key (backward compatible) fn load_api_keys() -> HashSet { let mut keys = HashSet::new(); // Format 1: API_KEYS (comma-separated) if let Ok(keys_str) = std::env::var("API_KEYS") { for key in keys_str.split(',') { let key = key.trim(); if !key.is_empty() { keys.insert(key.to_string()); tracing::info!("Loaded API key from API_KEYS: {}...", &key[..key.len().min(8)]); } } } // Format 2: API_KEY_1, API_KEY_2, etc. for i in 1..=10 { if let Ok(key) = std::env::var(format!("API_KEY_{}", i)) { let key = key.trim().to_string(); if !key.is_empty() { tracing::info!("Loaded API key from API_KEY_{}: {}...", i, &key[..key.len().min(8)]); keys.insert(key); } } } // Format 3: Single API_KEY if let Ok(key) = std::env::var("API_KEY") { let key = key.trim().to_string(); if !key.is_empty() { tracing::info!("Loaded API key from API_KEY: {}...", &key[..key.len().min(8)]); keys.insert(key); } } if keys.is_empty() { tracing::warn!("No API keys configured! Set API_KEYS, API_KEY_1, or API_KEY environment variable."); tracing::warn!("API authentication is DISABLED - all requests will be allowed."); } else { tracing::info!("Loaded {} API key(s) for authentication", keys.len()); } keys } /// Static API keys set (loaded once at startup) static API_KEYS: Lazy> = Lazy::new(load_api_keys); /// Check if authentication is enabled fn is_auth_enabled() -> bool { !API_KEYS.is_empty() } /// Validate API key fn validate_api_key(key: &str) -> bool { API_KEYS.contains(key) } /// API Key authentication middleware for proxy endpoints pub async fn auth_middleware(request: Request, next: Next) -> Result { let path = request.uri().path(); // Log the request tracing::debug!("Request: {} {}", request.method(), request.uri()); // Skip authentication for health check endpoints if path == "/healthz" || path == "/api/health" { return Ok(next.run(request).await); } // Skip authentication for static files (frontend) if !path.starts_with("/v1") && !path.starts_with("/api") { return Ok(next.run(request).await); } // Skip authentication for /api/* management endpoints // (These should be protected by HuggingFace Spaces password) if path.starts_with("/api/") { return Ok(next.run(request).await); } // If no API keys configured, allow all requests if !is_auth_enabled() { return Ok(next.run(request).await); } // Extract API key from headers // Support: Authorization: Bearer or X-API-Key: let api_key = request .headers() .get(header::AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|s| { // Support both "Bearer " and just "" s.strip_prefix("Bearer ").or(Some(s)) }) .or_else(|| { request .headers() .get("x-api-key") .and_then(|h| h.to_str().ok()) }); match api_key { Some(key) if validate_api_key(key) => { tracing::debug!("API key validated successfully"); Ok(next.run(request).await) } Some(_) => { tracing::warn!("Invalid API key provided for {}", path); Err(unauthorized_response("Invalid API key")) } None => { tracing::warn!("No API key provided for {}", path); Err(unauthorized_response("API key required. Provide via 'Authorization: Bearer ' or 'X-API-Key: ' header")) } } } /// Generate unauthorized response fn unauthorized_response(message: &str) -> Response { let body = serde_json::json!({ "error": { "message": message, "type": "authentication_error", "code": "invalid_api_key" } }); ( StatusCode::UNAUTHORIZED, [(header::CONTENT_TYPE, "application/json")], serde_json::to_string(&body).unwrap(), ) .into_response() } #[cfg(test)] mod tests { use super::*; #[test] fn test_api_key_loading() { // This test would need to be run with environment variables set assert!(true); } }