// API Key 认证中间件 use axum::{ extract::State, extract::Request, http::{header, StatusCode}, middleware::Next, response::Response, }; use std::sync::Arc; use tokio::sync::RwLock; use crate::proxy::{ProxyAuthMode, ProxySecurityConfig}; /// API Key 认证中间件 (代理接口使用,遵循 auth_mode) pub async fn auth_middleware( state: State>>, request: Request, next: Next, ) -> Result { auth_middleware_internal(state, request, next, false).await } /// 管理接口认证中间件 (管理接口使用,强制严格鉴权) pub async fn admin_auth_middleware( state: State>>, request: Request, next: Next, ) -> Result { auth_middleware_internal(state, request, next, true).await } /// 内部认证逻辑 async fn auth_middleware_internal( State(security): State>>, request: Request, next: Next, force_strict: bool, ) -> Result { let method = request.method().clone(); let path = request.uri().path().to_string(); // 过滤心跳和健康检查请求,避免日志噪音 let is_health_check = path == "/healthz" || path == "/api/health" || path == "/health"; let is_internal_endpoint = path.starts_with("/internal/"); if !path.contains("event_logging") && !is_health_check { tracing::info!("Request: {} {}", method, path); } else { tracing::trace!("Heartbeat/Health: {} {}", method, path); } // Allow CORS preflight regardless of auth policy. if method == axum::http::Method::OPTIONS { return Ok(next.run(request).await); } let security = security.read().await.clone(); let effective_mode = security.effective_auth_mode(); // 权限检查逻辑 if !force_strict { // AI 代理接口 (v1/chat/completions 等) if matches!(effective_mode, ProxyAuthMode::Off) { // [FIX] 即使 auth_mode=Off,也需要尝试识别 User Token 以记录使用情况 // 先检查是否携带了 User Token let api_key = request .headers() .get(header::AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|s| s.strip_prefix("Bearer ").or(Some(s))) .or_else(|| { request .headers() .get("x-api-key") .and_then(|h| h.to_str().ok()) }); if let Some(token) = api_key { // 尝试验证是否为 User Token(不阻止请求,只记录) if let Ok(Some(user_token)) = crate::modules::user_token_db::get_token_by_value(token) { let identity = UserTokenIdentity { token_id: user_token.id, token: user_token.token, username: user_token.username, }; // 注入 identity 到请求 let (mut parts, body) = request.into_parts(); parts.extensions.insert(identity); let request = Request::from_parts(parts, body); return Ok(next.run(request).await); } } return Ok(next.run(request).await); } if matches!(effective_mode, ProxyAuthMode::AllExceptHealth) && is_health_check { return Ok(next.run(request).await); } // 内部端点 (/internal/*) 豁免鉴权 - 用于 warmup 等内部功能 if is_internal_endpoint { tracing::debug!("Internal endpoint bypassed auth: {}", path); return Ok(next.run(request).await); } } else { // 管理接口 (/api/*) // 1. 如果全局鉴权关闭,则管理接口也放行 (除非是强制局域网模式) if matches!(effective_mode, ProxyAuthMode::Off) { return Ok(next.run(request).await); } // 2. 健康检查在所有模式下对管理接口放行 if is_health_check { return Ok(next.run(request).await); } } // 从 header 中提取 API key let api_key = request .headers() .get(header::AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|s| s.strip_prefix("Bearer ").or(Some(s))) .or_else(|| { request .headers() .get("x-api-key") .and_then(|h| h.to_str().ok()) }) .or_else(|| { request .headers() .get("x-goog-api-key") .and_then(|h| h.to_str().ok()) }); if security.api_key.is_empty() && (security.admin_password.is_none() || security.admin_password.as_ref().unwrap().is_empty()) { if force_strict { tracing::error!("Admin auth is required but both api_key and admin_password are empty; denying request"); return Err(StatusCode::UNAUTHORIZED); } tracing::error!("Proxy auth is enabled but api_key is empty; denying request"); return Err(StatusCode::UNAUTHORIZED); } // 认证逻辑 let authorized = if force_strict { // 管理接口:优先使用独立的 admin_password,如果没有则回退使用 api_key match &security.admin_password { Some(pwd) if !pwd.is_empty() => { api_key.map(|k| k == pwd).unwrap_or(false) } _ => { // 回退使用 api_key api_key.map(|k| k == security.api_key).unwrap_or(false) } } } else { // AI 代理接口:仅允许使用 api_key api_key.map(|k| k == security.api_key).unwrap_or(false) }; if authorized { Ok(next.run(request).await) } else if !force_strict && api_key.is_some() { // 尝试验证 UserToken let token = api_key.unwrap(); // 提取 IP (复用逻辑) let client_ip = request .headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) .map(|s| s.split(',').next().unwrap_or(s).trim().to_string()) .or_else(|| { request .headers() .get("x-real-ip") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) }) .unwrap_or_else(|| "127.0.0.1".to_string()); // Default fallback // 验证 Token match crate::modules::user_token_db::validate_token(token, &client_ip) { Ok((true, _)) => { // Token 有效,查询信息以便传递 if let Ok(Some(user_token)) = crate::modules::user_token_db::get_token_by_value(token) { let identity = UserTokenIdentity { token_id: user_token.id, token: user_token.token, username: user_token.username, }; // [FIX] 将身份信息注入到请求 extensions 中,而不是响应 // 这样 monitor_middleware 在处理请求时就能获取到 identity // 因为中间件执行顺序:auth (外层) -> monitor (内层) -> handler // 响应返回时:handler -> monitor -> auth // 如果注入到 response,monitor 执行时 identity 还不存在 let (mut parts, body) = request.into_parts(); parts.extensions.insert(identity); let request = Request::from_parts(parts, body); // 执行请求 let response = next.run(request).await; Ok(response) } else { Err(StatusCode::UNAUTHORIZED) } } Ok((false, reason)) => { let reason_str = reason.unwrap_or_else(|| "Access denied".to_string()); tracing::warn!("UserToken rejected: {}", reason_str); let body = serde_json::json!({ "error": { "message": reason_str, "type": "token_rejected", "code": "token_rejected" } }); let response = axum::response::Response::builder() .status(StatusCode::FORBIDDEN) .header("Content-Type", "application/json") .body(axum::body::Body::from(serde_json::to_string(&body).unwrap())) .unwrap(); Ok(response) } Err(e) => { tracing::error!("UserToken validation error: {}", e); Err(StatusCode::INTERNAL_SERVER_ERROR) } } } else { Err(StatusCode::UNAUTHORIZED) } } /// 用户令牌身份信息 (传递给 Monitor 使用) #[derive(Clone, Debug)] pub struct UserTokenIdentity { pub token_id: String, #[allow(dead_code)] // 保留原始 token 便于审计/调试 pub token: String, pub username: String, } #[cfg(test)] mod tests { use super::*; use crate::proxy::ProxyAuthMode; #[tokio::test] async fn test_admin_auth_with_password() { let security = Arc::new(RwLock::new(ProxySecurityConfig { auth_mode: ProxyAuthMode::Strict, api_key: "sk-api".to_string(), admin_password: Some("admin123".to_string()), allow_lan_access: true, port: 8045, security_monitor: crate::proxy::config::SecurityMonitorConfig::default(), })); // 模拟请求 - 管理接口使用正确的管理密码 let req = Request::builder() .header("Authorization", "Bearer admin123") .uri("/admin/stats") .body(axum::body::Body::empty()) .unwrap(); // 此测试由于涉及 Next 中间件调用比较复杂,主要验证核心逻辑 // 我们在 auth_middleware_internal 基础上做了逻辑校验即可 } #[test] fn test_auth_placeholder() { assert!(true); } }