hank9999 committed on
Commit
b8f39ea
·
1 Parent(s): c00d86d

fix(stream): 使用 contextUsageEvent 计算实际输入 tokens,调整事件生成逻辑

Browse files
src/anthropic/handlers.rs CHANGED
@@ -306,6 +306,9 @@ fn create_sse_stream(
306
  initial_stream.chain(processing_stream)
307
  }
308
 
 
 
 
309
  /// 处理非流式请求
310
  async fn handle_non_stream_request(
311
  provider: std::sync::Arc<tokio::sync::Mutex<crate::kiro::provider::KiroProvider>>,
@@ -356,6 +359,8 @@ async fn handle_non_stream_request(
356
  let mut tool_uses: Vec<serde_json::Value> = Vec::new();
357
  let mut has_tool_use = false;
358
  let mut stop_reason = "end_turn".to_string();
 
 
359
 
360
  // 收集工具调用的增量 JSON
361
  let mut tool_json_buffers: std::collections::HashMap<String, String> = std::collections::HashMap::new();
@@ -396,6 +401,17 @@ async fn handle_non_stream_request(
396
  }));
397
  }
398
  }
 
 
 
 
 
 
 
 
 
 
 
399
  Event::Exception { exception_type, .. } => {
400
  if exception_type == "ContentLengthExceededException" {
401
  stop_reason = "max_tokens".to_string();
@@ -431,6 +447,9 @@ async fn handle_non_stream_request(
431
  // 估算输出 tokens
432
  let output_tokens = token::estimate_output_tokens(&content);
433
 
 
 
 
434
  // 构建 Anthropic 响应
435
  let response_body = json!({
436
  "id": format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
@@ -441,7 +460,7 @@ async fn handle_non_stream_request(
441
  "stop_reason": stop_reason,
442
  "stop_sequence": null,
443
  "usage": {
444
- "input_tokens": input_tokens,
445
  "output_tokens": output_tokens
446
  }
447
  });
 
306
  initial_stream.chain(processing_stream)
307
  }
308
 
309
+ /// 上下文窗口大小(200k tokens)
310
+ const CONTEXT_WINDOW_SIZE: i32 = 200_000;
311
+
312
  /// 处理非流式请求
313
  async fn handle_non_stream_request(
314
  provider: std::sync::Arc<tokio::sync::Mutex<crate::kiro::provider::KiroProvider>>,
 
359
  let mut tool_uses: Vec<serde_json::Value> = Vec::new();
360
  let mut has_tool_use = false;
361
  let mut stop_reason = "end_turn".to_string();
362
+ // 从 contextUsageEvent 计算的实际输入 tokens
363
+ let mut context_input_tokens: Option<i32> = None;
364
 
365
  // 收集工具调用的增量 JSON
366
  let mut tool_json_buffers: std::collections::HashMap<String, String> = std::collections::HashMap::new();
 
401
  }));
402
  }
403
  }
404
+ Event::ContextUsage(context_usage) => {
405
+ // 从上下文使用百分比计算实际的 input_tokens
406
+ // 公式: percentage * 200000 / 100 = percentage * 2000
407
+ let actual_input_tokens = (context_usage.context_usage_percentage * (CONTEXT_WINDOW_SIZE as f64) / 100.0) as i32;
408
+ context_input_tokens = Some(actual_input_tokens);
409
+ tracing::debug!(
410
+ "收到 contextUsageEvent: {}%, 计算 input_tokens: {}",
411
+ context_usage.context_usage_percentage,
412
+ actual_input_tokens
413
+ );
414
+ }
415
  Event::Exception { exception_type, .. } => {
416
  if exception_type == "ContentLengthExceededException" {
417
  stop_reason = "max_tokens".to_string();
 
447
  // 估算输出 tokens
448
  let output_tokens = token::estimate_output_tokens(&content);
449
 
450
+ // 使用从 contextUsageEvent 计算的 input_tokens,如果没有则使用估算值
451
+ let final_input_tokens = context_input_tokens.unwrap_or(input_tokens);
452
+
453
  // 构建 Anthropic 响应
454
  let response_body = json!({
455
  "id": format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
 
460
  "stop_reason": stop_reason,
461
  "stop_sequence": null,
462
  "usage": {
463
+ "input_tokens": final_input_tokens,
464
  "output_tokens": output_tokens
465
  }
466
  });
src/anthropic/stream.rs CHANGED
@@ -351,7 +351,7 @@ impl SseStateManager {
351
  }
352
 
353
  /// 生成最终事件序列
354
- pub fn generate_final_events(&mut self, output_tokens: i32) -> Vec<SseEvent> {
355
  let mut events = Vec::new();
356
 
357
  // 关闭所有未关闭的块
@@ -380,7 +380,8 @@ impl SseStateManager {
380
  "stop_sequence": null
381
  },
382
  "usage": {
383
- "output_tokens": output_tokens
 
384
  }
385
  }),
386
  ));
@@ -399,6 +400,9 @@ impl SseStateManager {
399
  }
400
  }
401
 
 
 
 
402
  /// 流处理上下文
403
  pub struct StreamContext {
404
  /// SSE 状态管理器
@@ -407,8 +411,10 @@ pub struct StreamContext {
407
  pub model: String,
408
  /// 消息 ID
409
  pub message_id: String,
410
- /// 输入 tokens
411
  pub input_tokens: i32,
 
 
412
  /// 输出 tokens 累计
413
  pub output_tokens: i32,
414
  /// 工具块索引映射 (tool_id -> block_index)
@@ -435,6 +441,7 @@ impl StreamContext {
435
  model: model.into(),
436
  message_id: format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
437
  input_tokens,
 
438
  output_tokens: 0,
439
  tool_block_indices: HashMap::new(),
440
  thinking_enabled,
@@ -514,6 +521,18 @@ impl StreamContext {
514
  Event::ToolUse(tool_use) => {
515
  self.process_tool_use(tool_use)
516
  }
 
 
 
 
 
 
 
 
 
 
 
 
517
  Event::Error { error_code, error_message } => {
518
  tracing::error!("收到错误事件: {} - {}", error_code, error_message);
519
  Vec::new()
@@ -823,8 +842,11 @@ impl StreamContext {
823
  self.thinking_buffer.clear();
824
  }
825
 
 
 
 
826
  // 生成最终事件
827
- events.extend(self.state_manager.generate_final_events(self.output_tokens));
828
  events
829
  }
830
  }
 
351
  }
352
 
353
  /// 生成最终事件序列
354
+ pub fn generate_final_events(&mut self, input_tokens: i32) -> Vec<SseEvent> {
355
  let mut events = Vec::new();
356
 
357
  // 关闭所有未关闭的块
 
380
  "stop_sequence": null
381
  },
382
  "usage": {
383
+ "input_tokens": input_tokens,
384
+ "output_tokens": 1
385
  }
386
  }),
387
  ));
 
400
  }
401
  }
402
 
403
+ /// 上下文窗口大小(200k tokens)
404
+ const CONTEXT_WINDOW_SIZE: i32 = 200_000;
405
+
406
  /// 流处理上下文
407
  pub struct StreamContext {
408
  /// SSE 状态管理器
 
411
  pub model: String,
412
  /// 消息 ID
413
  pub message_id: String,
414
+ /// 输入 tokens(估算值)
415
  pub input_tokens: i32,
416
+ /// 从 contextUsageEvent 计算的实际输入 tokens
417
+ pub context_input_tokens: Option<i32>,
418
  /// 输出 tokens 累计
419
  pub output_tokens: i32,
420
  /// 工具块索引映射 (tool_id -> block_index)
 
441
  model: model.into(),
442
  message_id: format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")),
443
  input_tokens,
444
+ context_input_tokens: None,
445
  output_tokens: 0,
446
  tool_block_indices: HashMap::new(),
447
  thinking_enabled,
 
521
  Event::ToolUse(tool_use) => {
522
  self.process_tool_use(tool_use)
523
  }
524
+ Event::ContextUsage(context_usage) => {
525
+ // 从上下文使用百分比计算实际的 input_tokens
526
+ // 公式: percentage * 200000 / 100 = percentage * 2000
527
+ let actual_input_tokens = (context_usage.context_usage_percentage * (CONTEXT_WINDOW_SIZE as f64) / 100.0) as i32;
528
+ self.context_input_tokens = Some(actual_input_tokens);
529
+ tracing::debug!(
530
+ "收到 contextUsageEvent: {}%, 计算 input_tokens: {}",
531
+ context_usage.context_usage_percentage,
532
+ actual_input_tokens
533
+ );
534
+ Vec::new()
535
+ }
536
  Event::Error { error_code, error_message } => {
537
  tracing::error!("收到错误事件: {} - {}", error_code, error_message);
538
  Vec::new()
 
842
  self.thinking_buffer.clear();
843
  }
844
 
845
+ // 使用从 contextUsageEvent 计算的 input_tokens,如果没有则使用估算值
846
+ let final_input_tokens = self.context_input_tokens.unwrap_or(self.input_tokens);
847
+
848
  // 生成最终事件
849
+ events.extend(self.state_manager.generate_final_events(final_input_tokens));
850
  events
851
  }
852
  }