use axum::{ extract::{Request, State}, middleware::Next, response::Response, body::Body, }; use std::time::Instant; use crate::proxy::server::AppState; use crate::proxy::monitor::ProxyRequestLog; use serde_json::Value; use crate::proxy::middleware::auth::UserTokenIdentity; use futures::StreamExt; const MAX_REQUEST_LOG_SIZE: usize = 100 * 1024 * 1024; // 100MB const MAX_RESPONSE_LOG_SIZE: usize = 100 * 1024 * 1024; // 100MB for image responses /// Helper function to record User Token usage fn record_user_token_usage( user_token_identity: &Option, log: &ProxyRequestLog, user_agent: Option, ) { if let Some(identity) = user_token_identity { let _ = crate::modules::user_token_db::record_token_usage_and_ip( &identity.token_id, log.client_ip.as_deref().unwrap_or("127.0.0.1"), log.model.as_deref().unwrap_or("unknown"), log.input_tokens.unwrap_or(0) as i32, log.output_tokens.unwrap_or(0) as i32, log.status as u16, user_agent, ); } } pub async fn monitor_middleware( State(state): State, request: Request, next: Next, ) -> Response { let _logging_enabled = state.monitor.is_enabled(); let method = request.method().to_string(); let uri = request.uri().to_string(); if uri.contains("event_logging") || uri.contains("/api/") || uri.starts_with("/internal/") { return next.run(request).await; } let start = Instant::now(); // Extract client IP from headers (X-Forwarded-For or X-Real-IP) // IMPORTANT: Extract from Request headers, not Response headers (since we want the client's IP) // Note: We need to do this BEFORE consuming the request body if possible, or extract it from the original request let client_ip = request .headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) .map(|s| s.split(',').next().unwrap_or(s).trim().to_string()) .or_else(|| { request .headers() .get("x-real-ip") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) }); let user_agent = request .headers() .get("user-agent") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); let mut model = if uri.contains("/v1beta/models/") { uri.split("/v1beta/models/") .nth(1) .and_then(|s| s.split(':').next()) .map(|s| s.to_string()) } else { None }; let request_body_str; // [FIX] 从请求 extensions 提取 UserTokenIdentity (由 Auth 中间件注入) // 必须在处理 request body 之前提取,因为 into_parts() 后需要保留这个值 let user_token_identity = request.extensions().get::().cloned(); let request = if method == "POST" { let (parts, body) = request.into_parts(); match axum::body::to_bytes(body, MAX_REQUEST_LOG_SIZE).await { Ok(bytes) => { if model.is_none() { model = serde_json::from_slice::(&bytes).ok().and_then(|v| v.get("model").and_then(|m| m.as_str()).map(|s| s.to_string()) ); } request_body_str = if let Ok(s) = std::str::from_utf8(&bytes) { Some(s.to_string()) } else { Some("[Binary Request Data]".to_string()) }; Request::from_parts(parts, Body::from(bytes)) } Err(_) => { request_body_str = None; Request::from_parts(parts, Body::empty()) } } } else { request_body_str = None; request }; let response = next.run(request).await; // user_token_identity 已在上面从请求 extensions 中提取 let duration = start.elapsed().as_millis() as u64; let status = response.status().as_u16(); let content_type = response.headers().get("content-type") .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); // Extract account email from X-Account-Email header if present let account_email = response .headers() .get("X-Account-Email") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); // Extract mapped model from X-Mapped-Model header if present let mapped_model = response .headers() .get("X-Mapped-Model") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); // Determine protocol from URL path let protocol = if uri.contains("/v1/messages") { Some("anthropic".to_string()) } else if uri.contains("/v1beta/models") { Some("gemini".to_string()) } else if uri.starts_with("/v1/") { Some("openai".to_string()) } else { None }; // Client IP has been extracted at the beginning of the function // Extract username from UserTokenIdentity if present let username = user_token_identity.as_ref().map(|identity| identity.username.clone()); let monitor = state.monitor.clone(); let mut log = ProxyRequestLog { id: uuid::Uuid::new_v4().to_string(), timestamp: chrono::Utc::now().timestamp_millis(), method, url: uri, status, duration, model, mapped_model, account_email, client_ip, error: None, request_body: request_body_str, response_body: None, input_tokens: None, output_tokens: None, protocol, username, }; if content_type.contains("text/event-stream") { let (parts, body) = response.into_parts(); let mut stream = body.into_data_stream(); let (tx, rx) = tokio::sync::mpsc::channel(64); tokio::spawn(async move { let mut all_stream_data = Vec::new(); let mut last_few_bytes = Vec::new(); while let Some(chunk_res) = stream.next().await { if let Ok(chunk) = chunk_res { all_stream_data.extend_from_slice(&chunk); if chunk.len() > 8192 { last_few_bytes = chunk.slice(chunk.len()-8192..).to_vec(); } else { last_few_bytes.extend_from_slice(&chunk); if last_few_bytes.len() > 8192 { last_few_bytes.drain(0..last_few_bytes.len()-8192); } } let _ = tx.send(Ok::<_, axum::Error>(chunk)).await; } else if let Err(e) = chunk_res { let _ = tx.send(Err(axum::Error::new(e))).await; } } // Parse and consolidate stream data into readable format if let Ok(full_response) = std::str::from_utf8(&all_stream_data) { let mut thinking_content = String::new(); let mut response_content = String::new(); let mut thinking_signature = String::new(); let mut tool_calls: Vec = Vec::new(); for line in full_response.lines() { if !line.starts_with("data: ") { continue; } let json_str = line.trim_start_matches("data: ").trim(); if json_str == "[DONE]" { continue; } if let Ok(json) = serde_json::from_str::(json_str) { // OpenAI format: choices[0].delta.content / reasoning_content / tool_calls if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { for choice in choices { if let Some(delta) = choice.get("delta") { // Thinking/reasoning content if let Some(thinking) = delta.get("reasoning_content").and_then(|v| v.as_str()) { thinking_content.push_str(thinking); } // Main response content if let Some(content) = delta.get("content").and_then(|v| v.as_str()) { response_content.push_str(content); } // Tool calls if let Some(delta_tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { for tc in delta_tool_calls { if let Some(index) = tc.get("index").and_then(|i| i.as_u64()) { let idx = index as usize; while tool_calls.len() <= idx { tool_calls.push(serde_json::json!({ "id": "", "type": "function", "function": { "name": "", "arguments": "" } })); } let current_tc = &mut tool_calls[idx]; if let Some(id) = tc.get("id").and_then(|v| v.as_str()) { current_tc["id"] = Value::String(id.to_string()); } if let Some(func) = tc.get("function") { if let Some(name) = func.get("name").and_then(|v| v.as_str()) { current_tc["function"]["name"] = Value::String(name.to_string()); } if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) { let old_args = current_tc["function"]["arguments"].as_str().unwrap_or(""); current_tc["function"]["arguments"] = Value::String(format!("{}{}", old_args, args)); } } } } } } } } // Claude/Anthropic format: content_block_start, content_block_delta, etc. let msg_type = json.get("type").and_then(|t| t.as_str()); match msg_type { Some("content_block_start") => { if let (Some(index), Some(block)) = (json.get("index").and_then(|i| i.as_u64()), json.get("content_block")) { let idx = index as usize; if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") { let id = block.get("id").and_then(|v| v.as_str()).unwrap_or(""); let name = block.get("name").and_then(|v| v.as_str()).unwrap_or(""); while tool_calls.len() <= idx { tool_calls.push(Value::Null); } tool_calls[idx] = serde_json::json!({ "id": id, "type": "function", "function": { "name": name, "arguments": "" } }); } } } Some("content_block_delta") => { if let (Some(index), Some(delta)) = (json.get("index").and_then(|i| i.as_u64()), json.get("delta")) { let idx = index as usize; // Tool use input delta if let Some(delta_json) = delta.get("input_json_delta").and_then(|v| v.as_str()) { if idx < tool_calls.len() && !tool_calls[idx].is_null() { let old_args = tool_calls[idx]["function"]["arguments"].as_str().unwrap_or(""); tool_calls[idx]["function"]["arguments"] = Value::String(format!("{}{}", old_args, delta_json)); } } // Legacy/Native thinking block if let Some(thinking) = delta.get("thinking").and_then(|v| v.as_str()) { thinking_content.push_str(thinking); } // Text content if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { response_content.push_str(text); } } } Some("message_delta") => { if let Some(delta) = json.get("delta") { if let Some(usage) = delta.get("usage") { if let Some(output_tokens) = usage.get("output_tokens").and_then(|v| v.as_u64()) { log.output_tokens = Some(output_tokens as u32); } } } } _ => {} } // Legacy Claude delta (for older implementations or simplified streams) if msg_type.is_none() { if let Some(delta) = json.get("delta") { // Thinking block if let Some(thinking) = delta.get("thinking").and_then(|v| v.as_str()) { thinking_content.push_str(thinking); } // Thinking signature if let Some(sig) = delta.get("signature").and_then(|v| v.as_str()) { thinking_signature = sig.to_string(); } // Text content if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { response_content.push_str(text); } } } // Token usage extraction if let Some(usage) = json.get("usage") .or(json.get("usageMetadata")) .or(json.get("response").and_then(|r| r.get("usage"))) { log.input_tokens = usage.get("prompt_tokens") .or(usage.get("input_tokens")) .or(usage.get("promptTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); log.output_tokens = usage.get("completion_tokens") .or(usage.get("output_tokens")) .or(usage.get("candidatesTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); if log.input_tokens.is_none() && log.output_tokens.is_none() { log.output_tokens = usage.get("total_tokens") .or(usage.get("totalTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); } } } } // Build consolidated response object let mut consolidated = serde_json::Map::new(); if !thinking_content.is_empty() { consolidated.insert("thinking".to_string(), Value::String(thinking_content)); } if !thinking_signature.is_empty() { consolidated.insert("thinking_signature".to_string(), Value::String(thinking_signature)); } if !response_content.is_empty() { consolidated.insert("content".to_string(), Value::String(response_content)); } if !tool_calls.is_empty() { let clean_tool_calls: Vec = tool_calls.into_iter().filter(|v| !v.is_null()).collect(); if !clean_tool_calls.is_empty() { consolidated.insert("tool_calls".to_string(), Value::Array(clean_tool_calls)); } } if let Some(input) = log.input_tokens { consolidated.insert("input_tokens".to_string(), Value::Number(input.into())); } if let Some(output) = log.output_tokens { consolidated.insert("output_tokens".to_string(), Value::Number(output.into())); } if consolidated.is_empty() { // Fallback: store raw SSE data if parsing failed log.response_body = Some(full_response.to_string()); } else { log.response_body = Some(serde_json::to_string_pretty(&Value::Object(consolidated)).unwrap_or_else(|_| full_response.to_string())); } } else { log.response_body = Some(format!("[Binary Stream Data: {} bytes]", all_stream_data.len())); } // Fallback token extraction from tail if not already extracted if log.input_tokens.is_none() && log.output_tokens.is_none() { if let Ok(full_tail) = std::str::from_utf8(&last_few_bytes) { for line in full_tail.lines().rev() { if line.starts_with("data: ") && (line.contains("\"usage\"") || line.contains("\"usageMetadata\"")) { let json_str = line.trim_start_matches("data: ").trim(); if let Ok(json) = serde_json::from_str::(json_str) { if let Some(usage) = json.get("usage") .or(json.get("usageMetadata")) .or(json.get("response").and_then(|r| r.get("usage"))) { log.input_tokens = usage.get("prompt_tokens") .or(usage.get("input_tokens")) .or(usage.get("promptTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); log.output_tokens = usage.get("completion_tokens") .or(usage.get("output_tokens")) .or(usage.get("candidatesTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); break; } } } } } } if log.status >= 400 { log.error = Some("Stream Error or Failed".to_string()); } // Record User Token Usage record_user_token_usage(&user_token_identity, &log, user_agent.clone()); monitor.log_request(log).await; }); Response::from_parts(parts, Body::from_stream(tokio_stream::wrappers::ReceiverStream::new(rx))) } else if content_type.contains("application/json") || content_type.contains("text/") { let (parts, body) = response.into_parts(); match axum::body::to_bytes(body, MAX_RESPONSE_LOG_SIZE).await { Ok(bytes) => { if let Ok(s) = std::str::from_utf8(&bytes) { if let Ok(json) = serde_json::from_str::(&s) { // 支持 OpenAI "usage" 或 Gemini "usageMetadata" if let Some(usage) = json.get("usage").or(json.get("usageMetadata")) { log.input_tokens = usage.get("prompt_tokens") .or(usage.get("input_tokens")) .or(usage.get("promptTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); log.output_tokens = usage.get("completion_tokens") .or(usage.get("output_tokens")) .or(usage.get("candidatesTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); if log.input_tokens.is_none() && log.output_tokens.is_none() { log.output_tokens = usage.get("total_tokens") .or(usage.get("totalTokenCount")) .and_then(|v| v.as_u64()) .map(|v| v as u32); } } } log.response_body = Some(s.to_string()); } else { log.response_body = Some("[Binary Response Data]".to_string()); } if log.status >= 400 { log.error = log.response_body.clone(); } // Record User Token Usage record_user_token_usage(&user_token_identity, &log, user_agent.clone()); monitor.log_request(log).await; Response::from_parts(parts, Body::from(bytes)) } Err(_) => { log.response_body = Some("[Response too large (>100MB)]".to_string()); // Record User Token Usage (even if too large) record_user_token_usage(&user_token_identity, &log, user_agent.clone()); monitor.log_request(log).await; Response::from_parts(parts, Body::empty()) } } } else { log.response_body = Some(format!("[{}]", content_type)); // Record User Token Usage record_user_token_usage(&user_token_identity, &log, user_agent); monitor.log_request(log).await; response } }