// API key authentication middleware
use axum::{
extract::State,
extract::Request,
http::{header, StatusCode},
middleware::Next,
response::Response,
};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::proxy::{ProxyAuthMode, ProxySecurityConfig};
/// API-key authentication middleware for the proxy endpoints.
///
/// Delegates to [`auth_middleware_internal`] with strict mode disabled, so the
/// configured auth mode decides whether a request needs a credential.
///
/// # Errors
///
/// Returns whatever [`StatusCode`] the shared authentication logic rejects the
/// request with.
pub async fn auth_middleware(
    state: State<Arc<RwLock<ProxySecurityConfig>>>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Proxy routes follow the configured auth mode instead of forcing strict checks.
    const FORCE_STRICT: bool = false;
    let outcome = auth_middleware_internal(state, request, next, FORCE_STRICT).await;
    outcome
}
/// Authentication middleware for the management (admin) endpoints.
///
/// Delegates to [`auth_middleware_internal`] with strict mode enabled, which
/// selects the admin-specific branch of the shared authentication logic.
///
/// # Errors
///
/// Returns whatever [`StatusCode`] the shared authentication logic rejects the
/// request with.
pub async fn admin_auth_middleware(
    state: State<Arc<RwLock<ProxySecurityConfig>>>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Management routes always take the strict branch.
    const FORCE_STRICT: bool = true;
    let outcome = auth_middleware_internal(state, request, next, FORCE_STRICT).await;
    outcome
}
/// 内部认证逻辑
async fn auth_middleware_internal(
State(security): State<Arc<RwLock<ProxySecurityConfig>>>,
request: Request,
next: Next,
force_strict: bool,
) -> Result<Response, StatusCode> {
let method = request.method().clone();
let path = request.uri().path().to_string();
// 过滤心跳和健康检查请求,避免日志噪音
let is_health_check = path == "/healthz" || path == "/api/health" || path == "/health";
let is_internal_endpoint = path.starts_with("/internal/");
if !path.contains("event_logging") && !is_health_check {
tracing::info!("Request: {} {}", method, path);
} else {
tracing::trace!("Heartbeat/Health: {} {}", method, path);
}
// Allow CORS preflight regardless of auth policy.
if method == axum::http::Method::OPTIONS {
return Ok(next.run(request).await);
}
let security = security.read().await.clone();
let effective_mode = security.effective_auth_mode();
// 权限检查逻辑
if !force_strict {
// AI 代理接口 (v1/chat/completions 等)
if matches!(effective_mode, ProxyAuthMode::Off) {
// [FIX] 即使 auth_mode=Off,也需要尝试识别 User Token 以记录使用情况
// 先检查是否携带了 User Token
let api_key = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer ").or(Some(s)))
.or_else(|| {
request
.headers()
.get("x-api-key")
.and_then(|h| h.to_str().ok())
});
if let Some(token) = api_key {
// 尝试验证是否为 User Token(不阻止请求,只记录)
if let Ok(Some(user_token)) = crate::modules::user_token_db::get_token_by_value(token) {
let identity = UserTokenIdentity {
token_id: user_token.id,
token: user_token.token,
username: user_token.username,
};
// 注入 identity 到请求
let (mut parts, body) = request.into_parts();
parts.extensions.insert(identity);
let request = Request::from_parts(parts, body);
return Ok(next.run(request).await);
}
}
return Ok(next.run(request).await);
}
if matches!(effective_mode, ProxyAuthMode::AllExceptHealth) && is_health_check {
return Ok(next.run(request).await);
}
// 内部端点 (/internal/*) 豁免鉴权 - 用于 warmup 等内部功能
if is_internal_endpoint {
tracing::debug!("Internal endpoint bypassed auth: {}", path);
return Ok(next.run(request).await);
}
} else {
// 管理接口 (/api/*)
// 1. 如果全局鉴权关闭,则管理接口也放行 (除非是强制局域网模式)
if matches!(effective_mode, ProxyAuthMode::Off) {
return Ok(next.run(request).await);
}
// 2. 健康检查在所有模式下对管理接口放行
if is_health_check {
return Ok(next.run(request).await);
}
}
// 从 header 中提取 API key
let api_key = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer ").or(Some(s)))
.or_else(|| {
request
.headers()
.get("x-api-key")
.and_then(|h| h.to_str().ok())
})
.or_else(|| {
request
.headers()
.get("x-goog-api-key")
.and_then(|h| h.to_str().ok())
});
if security.api_key.is_empty() && (security.admin_password.is_none() || security.admin_password.as_ref().unwrap().is_empty()) {
if force_strict {
tracing::error!("Admin auth is required but both api_key and admin_password are empty; denying request");
return Err(StatusCode::UNAUTHORIZED);
}
tracing::error!("Proxy auth is enabled but api_key is empty; denying request");
return Err(StatusCode::UNAUTHORIZED);
}
// 认证逻辑
let authorized = if force_strict {
// 管理接口:优先使用独立的 admin_password,如果没有则回退使用 api_key
match &security.admin_password {
Some(pwd) if !pwd.is_empty() => {
api_key.map(|k| k == pwd).unwrap_or(false)
}
_ => {
// 回退使用 api_key
api_key.map(|k| k == security.api_key).unwrap_or(false)
}
}
} else {
// AI 代理接口:仅允许使用 api_key
api_key.map(|k| k == security.api_key).unwrap_or(false)
};
if authorized {
Ok(next.run(request).await)
} else if !force_strict && api_key.is_some() {
// 尝试验证 UserToken
let token = api_key.unwrap();
// 提取 IP (复用逻辑)
let client_ip = request
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.map(|s| s.split(',').next().unwrap_or(s).trim().to_string())
.or_else(|| {
request
.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
})
.unwrap_or_else(|| "127.0.0.1".to_string()); // Default fallback
// 验证 Token
match crate::modules::user_token_db::validate_token(token, &client_ip) {
Ok((true, _)) => {
// Token 有效,查询信息以便传递
if let Ok(Some(user_token)) = crate::modules::user_token_db::get_token_by_value(token) {
let identity = UserTokenIdentity {
token_id: user_token.id,
token: user_token.token,
username: user_token.username,
};
// [FIX] 将身份信息注入到请求 extensions 中,而不是响应
// 这样 monitor_middleware 在处理请求时就能获取到 identity
// 因为中间件执行顺序:auth (外层) -> monitor (内层) -> handler
// 响应返回时:handler -> monitor -> auth
// 如果注入到 response,monitor 执行时 identity 还不存在
let (mut parts, body) = request.into_parts();
parts.extensions.insert(identity);
let request = Request::from_parts(parts, body);
// 执行请求
let response = next.run(request).await;
Ok(response)
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
Ok((false, reason)) => {
let reason_str = reason.unwrap_or_else(|| "Access denied".to_string());
tracing::warn!("UserToken rejected: {}", reason_str);
let body = serde_json::json!({
"error": {
"message": reason_str,
"type": "token_rejected",
"code": "token_rejected"
}
});
let response = axum::response::Response::builder()
.status(StatusCode::FORBIDDEN)
.header("Content-Type", "application/json")
.body(axum::body::Body::from(serde_json::to_string(&body).unwrap()))
.unwrap();
Ok(response)
}
Err(e) => {
tracing::error!("UserToken validation error: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
/// Identity of a recognised user token, handed to the monitor layer.
///
/// `auth_middleware_internal` builds this from the stored token record and
/// inserts it into the request extensions before running the rest of the stack.
///
/// NOTE(review): `Debug` is derived, so formatting this value with `{:?}` prints
/// the raw `token` — confirm that is acceptable wherever it gets logged.
#[derive(Clone, Debug)]
pub struct UserTokenIdentity {
/// Copied from the stored token record's `id` field.
pub token_id: String,
/// Copied from the stored token record's `token` field.
#[allow(dead_code)] // Keep the original token around for auditing/debugging.
pub token: String,
/// Copied from the stored token record's `username` field.
pub username: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proxy::ProxyAuthMode;

    /// The credential carried by a well-formed admin request equals the configured
    /// admin password and differs from the proxy API key.
    ///
    /// Driving the whole middleware requires a `Next`, which is awkward to build in
    /// a unit test, so this checks the inputs the admin comparison works with.
    #[tokio::test]
    async fn test_admin_auth_with_password() {
        let security = Arc::new(RwLock::new(ProxySecurityConfig {
            auth_mode: ProxyAuthMode::Strict,
            api_key: "sk-api".to_string(),
            admin_password: Some("admin123".to_string()),
            allow_lan_access: true,
            port: 8045,
            security_monitor: crate::proxy::config::SecurityMonitorConfig::default(),
        }));

        // Simulated management request carrying the correct admin password.
        let req = Request::builder()
            .header("Authorization", "Bearer admin123")
            .uri("/admin/stats")
            .body(axum::body::Body::empty())
            .unwrap();

        let presented = req
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.strip_prefix("Bearer "));

        let config = security.read().await;
        assert_eq!(presented, config.admin_password.as_deref());
        assert_ne!(presented, Some(config.api_key.as_str()));
    }

    /// An identity stored in the request extensions can be read back unchanged,
    /// which is how downstream layers receive it.
    #[test]
    fn test_identity_round_trips_through_extensions() {
        let mut req = Request::builder()
            .uri("/v1/chat/completions")
            .body(axum::body::Body::empty())
            .unwrap();

        req.extensions_mut().insert(UserTokenIdentity {
            token_id: "tok-1".to_string(),
            token: "sk-user".to_string(),
            username: "alice".to_string(),
        });

        let stored = req
            .extensions()
            .get::<UserTokenIdentity>()
            .expect("identity was inserted above");
        assert_eq!(stored.token_id, "tok-1");
        assert_eq!(stored.token, "sk-user");
        assert_eq!(stored.username, "alice");
    }
}
|