hank9999 commited on
Commit ·
b8f39ea
1
Parent(s): c00d86d
fix(stream): 使用 contextUsageEvent 计算实际输入 tokens,调整事件生成逻辑
Browse files- src/anthropic/handlers.rs +20 -1
- src/anthropic/stream.rs +26 -4
src/anthropic/handlers.rs
CHANGED
|
@@ -306,6 +306,9 @@ fn create_sse_stream(
|
|
| 306 |
initial_stream.chain(processing_stream)
|
| 307 |
}
|
| 308 |
|
|
|
|
|
|
|
|
|
|
| 309 |
/// 处理非流式请求
|
| 310 |
async fn handle_non_stream_request(
|
| 311 |
provider: std::sync::Arc<tokio::sync::Mutex<crate::kiro::provider::KiroProvider>>,
|
|
@@ -356,6 +359,8 @@ async fn handle_non_stream_request(
|
|
| 356 |
let mut tool_uses: Vec<serde_json::Value> = Vec::new();
|
| 357 |
let mut has_tool_use = false;
|
| 358 |
let mut stop_reason = "end_turn".to_string();
|
|
|
|
|
|
|
| 359 |
|
| 360 |
// 收集工具调用的增量 JSON
|
| 361 |
let mut tool_json_buffers: std::collections::HashMap<String, String> = std::collections::HashMap::new();
|
|
@@ -396,6 +401,17 @@ async fn handle_non_stream_request(
|
|
| 396 |
}));
|
| 397 |
}
|
| 398 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
Event::Exception { exception_type, .. } => {
|
| 400 |
if exception_type == "ContentLengthExceededException" {
|
| 401 |
stop_reason = "max_tokens".to_string();
|
|
@@ -431,6 +447,9 @@ async fn handle_non_stream_request(
|
|
| 431 |
// 估算输出 tokens
|
| 432 |
let output_tokens = token::estimate_output_tokens(&content);
|
| 433 |
|
|
|
|
|
|
|
|
|
|
| 434 |
// 构建 Anthropic 响应
|
| 435 |
let response_body = json!({
|
| 436 |
"id": format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
|
|
@@ -441,7 +460,7 @@ async fn handle_non_stream_request(
|
|
| 441 |
"stop_reason": stop_reason,
|
| 442 |
"stop_sequence": null,
|
| 443 |
"usage": {
|
| 444 |
-
"input_tokens":
|
| 445 |
"output_tokens": output_tokens
|
| 446 |
}
|
| 447 |
});
|
|
|
|
| 306 |
initial_stream.chain(processing_stream)
|
| 307 |
}
|
| 308 |
|
| 309 |
+
/// 上下文窗口大小(200k tokens)
|
| 310 |
+
const CONTEXT_WINDOW_SIZE: i32 = 200_000;
|
| 311 |
+
|
| 312 |
/// 处理非流式请求
|
| 313 |
async fn handle_non_stream_request(
|
| 314 |
provider: std::sync::Arc<tokio::sync::Mutex<crate::kiro::provider::KiroProvider>>,
|
|
|
|
| 359 |
let mut tool_uses: Vec<serde_json::Value> = Vec::new();
|
| 360 |
let mut has_tool_use = false;
|
| 361 |
let mut stop_reason = "end_turn".to_string();
|
| 362 |
+
// 从 contextUsageEvent 计算的实际输入 tokens
|
| 363 |
+
let mut context_input_tokens: Option<i32> = None;
|
| 364 |
|
| 365 |
// 收集工具调用的增量 JSON
|
| 366 |
let mut tool_json_buffers: std::collections::HashMap<String, String> = std::collections::HashMap::new();
|
|
|
|
| 401 |
}));
|
| 402 |
}
|
| 403 |
}
|
| 404 |
+
Event::ContextUsage(context_usage) => {
|
| 405 |
+
// 从上下文使用百分比计算实际的 input_tokens
|
| 406 |
+
// 公式: percentage * 200000 / 100 = percentage * 2000
|
| 407 |
+
let actual_input_tokens = (context_usage.context_usage_percentage * (CONTEXT_WINDOW_SIZE as f64) / 100.0) as i32;
|
| 408 |
+
context_input_tokens = Some(actual_input_tokens);
|
| 409 |
+
tracing::debug!(
|
| 410 |
+
"收到 contextUsageEvent: {}%, 计算 input_tokens: {}",
|
| 411 |
+
context_usage.context_usage_percentage,
|
| 412 |
+
actual_input_tokens
|
| 413 |
+
);
|
| 414 |
+
}
|
| 415 |
Event::Exception { exception_type, .. } => {
|
| 416 |
if exception_type == "ContentLengthExceededException" {
|
| 417 |
stop_reason = "max_tokens".to_string();
|
|
|
|
| 447 |
// 估算输出 tokens
|
| 448 |
let output_tokens = token::estimate_output_tokens(&content);
|
| 449 |
|
| 450 |
+
// 使用从 contextUsageEvent 计算的 input_tokens,如果没有则使用估算值
|
| 451 |
+
let final_input_tokens = context_input_tokens.unwrap_or(input_tokens);
|
| 452 |
+
|
| 453 |
// 构建 Anthropic 响应
|
| 454 |
let response_body = json!({
|
| 455 |
"id": format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
|
|
|
|
| 460 |
"stop_reason": stop_reason,
|
| 461 |
"stop_sequence": null,
|
| 462 |
"usage": {
|
| 463 |
+
"input_tokens": final_input_tokens,
|
| 464 |
"output_tokens": output_tokens
|
| 465 |
}
|
| 466 |
});
|
src/anthropic/stream.rs
CHANGED
|
@@ -351,7 +351,7 @@ impl SseStateManager {
|
|
| 351 |
}
|
| 352 |
|
| 353 |
/// 生成最终事件序列
|
| 354 |
-
pub fn generate_final_events(&mut self,
|
| 355 |
let mut events = Vec::new();
|
| 356 |
|
| 357 |
// 关闭所有未关闭的块
|
|
@@ -380,7 +380,8 @@ impl SseStateManager {
|
|
| 380 |
"stop_sequence": null
|
| 381 |
},
|
| 382 |
"usage": {
|
| 383 |
-
"
|
|
|
|
| 384 |
}
|
| 385 |
}),
|
| 386 |
));
|
|
@@ -399,6 +400,9 @@ impl SseStateManager {
|
|
| 399 |
}
|
| 400 |
}
|
| 401 |
|
|
|
|
|
|
|
|
|
|
| 402 |
/// 流处理上下文
|
| 403 |
pub struct StreamContext {
|
| 404 |
/// SSE 状态管理器
|
|
@@ -407,8 +411,10 @@ pub struct StreamContext {
|
|
| 407 |
pub model: String,
|
| 408 |
/// 消息 ID
|
| 409 |
pub message_id: String,
|
| 410 |
-
/// 输入 tokens
|
| 411 |
pub input_tokens: i32,
|
|
|
|
|
|
|
| 412 |
/// 输出 tokens 累计
|
| 413 |
pub output_tokens: i32,
|
| 414 |
/// 工具块索引映射 (tool_id -> block_index)
|
|
@@ -435,6 +441,7 @@ impl StreamContext {
|
|
| 435 |
model: model.into(),
|
| 436 |
message_id: format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
|
| 437 |
input_tokens,
|
|
|
|
| 438 |
output_tokens: 0,
|
| 439 |
tool_block_indices: HashMap::new(),
|
| 440 |
thinking_enabled,
|
|
@@ -514,6 +521,18 @@ impl StreamContext {
|
|
| 514 |
Event::ToolUse(tool_use) => {
|
| 515 |
self.process_tool_use(tool_use)
|
| 516 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
Event::Error { error_code, error_message } => {
|
| 518 |
tracing::error!("收到错误事件: {} - {}", error_code, error_message);
|
| 519 |
Vec::new()
|
|
@@ -823,8 +842,11 @@ impl StreamContext {
|
|
| 823 |
self.thinking_buffer.clear();
|
| 824 |
}
|
| 825 |
|
|
|
|
|
|
|
|
|
|
| 826 |
// 生成最终事件
|
| 827 |
-
events.extend(self.state_manager.generate_final_events(
|
| 828 |
events
|
| 829 |
}
|
| 830 |
}
|
|
|
|
| 351 |
}
|
| 352 |
|
| 353 |
/// 生成最终事件序列
|
| 354 |
+
pub fn generate_final_events(&mut self, input_tokens: i32) -> Vec<SseEvent> {
|
| 355 |
let mut events = Vec::new();
|
| 356 |
|
| 357 |
// 关闭所有未关闭的块
|
|
|
|
| 380 |
"stop_sequence": null
|
| 381 |
},
|
| 382 |
"usage": {
|
| 383 |
+
"input_tokens": input_tokens,
|
| 384 |
+
"output_tokens": 1
|
| 385 |
}
|
| 386 |
}),
|
| 387 |
));
|
|
|
|
| 400 |
}
|
| 401 |
}
|
| 402 |
|
| 403 |
+
/// 上下文窗口大小(200k tokens)
|
| 404 |
+
const CONTEXT_WINDOW_SIZE: i32 = 200_000;
|
| 405 |
+
|
| 406 |
/// 流处理上下文
|
| 407 |
pub struct StreamContext {
|
| 408 |
/// SSE 状态管理器
|
|
|
|
| 411 |
pub model: String,
|
| 412 |
/// 消息 ID
|
| 413 |
pub message_id: String,
|
| 414 |
+
/// 输入 tokens(估算值)
|
| 415 |
pub input_tokens: i32,
|
| 416 |
+
/// 从 contextUsageEvent 计算的实际输入 tokens
|
| 417 |
+
pub context_input_tokens: Option<i32>,
|
| 418 |
/// 输出 tokens 累计
|
| 419 |
pub output_tokens: i32,
|
| 420 |
/// 工具块索引映射 (tool_id -> block_index)
|
|
|
|
| 441 |
model: model.into(),
|
| 442 |
message_id: format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
|
| 443 |
input_tokens,
|
| 444 |
+
context_input_tokens: None,
|
| 445 |
output_tokens: 0,
|
| 446 |
tool_block_indices: HashMap::new(),
|
| 447 |
thinking_enabled,
|
|
|
|
| 521 |
Event::ToolUse(tool_use) => {
|
| 522 |
self.process_tool_use(tool_use)
|
| 523 |
}
|
| 524 |
+
Event::ContextUsage(context_usage) => {
|
| 525 |
+
// 从上下文使用百分比计算实际的 input_tokens
|
| 526 |
+
// 公式: percentage * 200000 / 100 = percentage * 2000
|
| 527 |
+
let actual_input_tokens = (context_usage.context_usage_percentage * (CONTEXT_WINDOW_SIZE as f64) / 100.0) as i32;
|
| 528 |
+
self.context_input_tokens = Some(actual_input_tokens);
|
| 529 |
+
tracing::debug!(
|
| 530 |
+
"收到 contextUsageEvent: {}%, 计算 input_tokens: {}",
|
| 531 |
+
context_usage.context_usage_percentage,
|
| 532 |
+
actual_input_tokens
|
| 533 |
+
);
|
| 534 |
+
Vec::new()
|
| 535 |
+
}
|
| 536 |
Event::Error { error_code, error_message } => {
|
| 537 |
tracing::error!("收到错误事件: {} - {}", error_code, error_message);
|
| 538 |
Vec::new()
|
|
|
|
| 842 |
self.thinking_buffer.clear();
|
| 843 |
}
|
| 844 |
|
| 845 |
+
// 使用从 contextUsageEvent 计算的 input_tokens,如果没有则使用估算值
|
| 846 |
+
let final_input_tokens = self.context_input_tokens.unwrap_or(self.input_tokens);
|
| 847 |
+
|
| 848 |
// 生成最终事件
|
| 849 |
+
events.extend(self.state_manager.generate_final_events(final_input_tokens));
|
| 850 |
events
|
| 851 |
}
|
| 852 |
}
|