hank9999 commited on
Commit
f0d4864
·
1 Parent(s): 256ba56

fix: 为 count_all_tokens 增加 model 参数,支持动态指定模型

Browse files
src/anthropic/handlers.rs CHANGED
@@ -143,7 +143,7 @@ pub async fn post_messages(
143
  tracing::debug!("Kiro request body: {}", request_body);
144
 
145
  // 估算输入 tokens
146
- let input_tokens = token::count_all_tokens(payload.system, payload.messages, payload.tools) as i32;
147
 
148
  // 检查是否启用了thinking
149
  let thinking_enabled = payload.thinking
@@ -440,7 +440,7 @@ pub async fn count_tokens(JsonExtractor(payload): JsonExtractor<CountTokensReque
440
  "Received POST /v1/messages/count_tokens request"
441
  );
442
 
443
- let total_tokens = token::count_all_tokens(payload.system, payload.messages, payload.tools) as i32;
444
 
445
  Json(CountTokensResponse {
446
  input_tokens: total_tokens.max(1) as i32,
 
143
  tracing::debug!("Kiro request body: {}", request_body);
144
 
145
  // 估算输入 tokens
146
+ let input_tokens = token::count_all_tokens(payload.model.clone(), payload.system, payload.messages, payload.tools) as i32;
147
 
148
  // 检查是否启用了thinking
149
  let thinking_enabled = payload.thinking
 
440
  "Received POST /v1/messages/count_tokens request"
441
  );
442
 
443
+ let total_tokens = token::count_all_tokens(payload.model, payload.system, payload.messages, payload.tools) as i32;
444
 
445
  Json(CountTokensResponse {
446
  input_tokens: total_tokens.max(1) as i32,
src/anthropic/token.rs CHANGED
@@ -100,14 +100,14 @@ pub fn count_tokens(text: &str) -> u64 {
100
  /// 估算请求的输入 tokens
101
  ///
102
  /// 优先调用远程 API,失败时回退到本地计算
103
- pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
104
  // 检查是否配置了远程 API
105
  if let Some(config) = get_config() {
106
  if let Some(api_url) = &config.api_url {
107
  // 尝试调用远程 API
108
  let result = tokio::task::block_in_place(|| {
109
  tokio::runtime::Handle::current().block_on(
110
- call_remote_count_tokens(api_url, config, &system, &messages, &tools)
111
  )
112
  });
113
 
@@ -131,6 +131,7 @@ pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec
131
  async fn call_remote_count_tokens(
132
  api_url: &str,
133
  config: &CountTokensConfig,
 
134
  system: &Option<Vec<SystemMessage>>,
135
  messages: &Vec<Message>,
136
  tools: &Option<Vec<Tool>>,
@@ -139,7 +140,7 @@ async fn call_remote_count_tokens(
139
 
140
  // 构建请求体
141
  let request = CountTokensRequest {
142
- model: "claude-sonnet-4-5-20250929".to_string(), // 模型名称用于 token 计算
143
  messages: messages.clone(),
144
  system: system.clone(),
145
  tools: tools.clone(),
 
100
  /// 估算请求的输入 tokens
101
  ///
102
  /// 优先调用远程 API,失败时回退到本地计算
103
+ pub(crate) fn count_all_tokens(model: String, system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
104
  // 检查是否配置了远程 API
105
  if let Some(config) = get_config() {
106
  if let Some(api_url) = &config.api_url {
107
  // 尝试调用远程 API
108
  let result = tokio::task::block_in_place(|| {
109
  tokio::runtime::Handle::current().block_on(
110
+ call_remote_count_tokens(api_url, config, model, &system, &messages, &tools)
111
  )
112
  });
113
 
 
131
  async fn call_remote_count_tokens(
132
  api_url: &str,
133
  config: &CountTokensConfig,
134
+ model: String,
135
  system: &Option<Vec<SystemMessage>>,
136
  messages: &Vec<Message>,
137
  tools: &Option<Vec<Tool>>,
 
140
 
141
  // 构建请求体
142
  let request = CountTokensRequest {
143
+ model: model, // 模型名称用于 token 计算
144
  messages: messages.clone(),
145
  system: system.clone(),
146
  tools: tools.clone(),