hank9999 committed on
Commit
256ba56
·
1 Parent(s): d40d682

feat: 添加远程 count_tokens API 支持并完善相关配置逻辑

Browse files
Files changed (3) hide show
  1. src/anthropic/token.rs +104 -1
  2. src/anthropic/types.rs +8 -6
  3. src/main.rs +7 -0
src/anthropic/token.rs CHANGED
@@ -7,7 +7,34 @@
7
  //! - 西文字符:每个计 1 个字符单位
8
  //! - 4 个字符单位 = 1 token(四舍五入)
9
 
10
- use crate::anthropic::types::{Message, SystemMessage, Tool};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  /// 判断字符是否为非西文字符
@@ -71,7 +98,83 @@ pub fn count_tokens(text: &str) -> u64 {
71
 
72
 
73
  /// 估算请求的输入 tokens
 
 
74
  pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  let mut total = 0;
76
 
77
  // 系统消息
 
7
  //! - 西文字符:每个计 1 个字符单位
8
  //! - 4 个字符单位 = 1 token(四舍五入)
9
 
10
+ use std::sync::OnceLock;
11
+ use crate::anthropic::types::{CountTokensRequest, CountTokensResponse, Message, SystemMessage, Tool};
12
+
13
+ /// Count Tokens API 配置
14
+ #[derive(Clone, Default)]
15
+ pub struct CountTokensConfig {
16
+ /// 外部 count_tokens API 地址
17
+ pub api_url: Option<String>,
18
+ /// count_tokens API 密钥
19
+ pub api_key: Option<String>,
20
+ /// count_tokens API 认证类型("x-api-key" 或 "bearer")
21
+ pub auth_type: String,
22
+ }
23
+
24
+ /// 全局配置存储
25
+ static COUNT_TOKENS_CONFIG: OnceLock<CountTokensConfig> = OnceLock::new();
26
+
27
+ /// 初始化 count_tokens 配置
28
+ ///
29
+ /// 应在应用启动时调用一次
30
+ pub fn init_config(config: CountTokensConfig) {
31
+ let _ = COUNT_TOKENS_CONFIG.set(config);
32
+ }
33
+
34
+ /// 获取配置
35
+ fn get_config() -> Option<&'static CountTokensConfig> {
36
+ COUNT_TOKENS_CONFIG.get()
37
+ }
38
 
39
 
40
  /// 判断字符是否为非西文字符
 
98
 
99
 
100
  /// 估算请求的输入 tokens
101
+ ///
102
+ /// 优先调用远程 API,失败时回退到本地计算
103
  pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
104
+ // 检查是否配置了远程 API
105
+ if let Some(config) = get_config() {
106
+ if let Some(api_url) = &config.api_url {
107
+ // 尝试调用远程 API
108
+ let result = tokio::task::block_in_place(|| {
109
+ tokio::runtime::Handle::current().block_on(
110
+ call_remote_count_tokens(api_url, config, &system, &messages, &tools)
111
+ )
112
+ });
113
+
114
+ match result {
115
+ Ok(tokens) => {
116
+ tracing::debug!("远程 count_tokens API 返回: {}", tokens);
117
+ return tokens;
118
+ }
119
+ Err(e) => {
120
+ tracing::warn!("远程 count_tokens API 调用失败,回退到本地计算: {}", e);
121
+ }
122
+ }
123
+ }
124
+ }
125
+
126
+ // 本地计算
127
+ count_all_tokens_local(system, messages, tools)
128
+ }
129
+
130
+ /// 调用远程 count_tokens API
131
+ async fn call_remote_count_tokens(
132
+ api_url: &str,
133
+ config: &CountTokensConfig,
134
+ system: &Option<Vec<SystemMessage>>,
135
+ messages: &Vec<Message>,
136
+ tools: &Option<Vec<Tool>>,
137
+ ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
138
+ let client = reqwest::Client::new();
139
+
140
+ // 构建请求体
141
+ let request = CountTokensRequest {
142
+ model: "claude-sonnet-4-5-20250929".to_string(), // 模型名称用于 token 计算
143
+ messages: messages.clone(),
144
+ system: system.clone(),
145
+ tools: tools.clone(),
146
+ };
147
+
148
+ // 构建请求
149
+ let mut req_builder = client.post(api_url);
150
+
151
+ // 设置认证头
152
+ if let Some(api_key) = &config.api_key {
153
+ if config.auth_type == "bearer" {
154
+ req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
155
+ } else {
156
+ req_builder = req_builder.header("x-api-key", api_key);
157
+ }
158
+ }
159
+
160
+ // 发送请求
161
+ let response = req_builder
162
+ .header("Content-Type", "application/json")
163
+ .json(&request)
164
+ .timeout(std::time::Duration::from_secs(5))
165
+ .send()
166
+ .await?;
167
+
168
+ if !response.status().is_success() {
169
+ return Err(format!("API 返回错误状态: {}", response.status()).into());
170
+ }
171
+
172
+ let result: CountTokensResponse = response.json().await?;
173
+ Ok(result.input_tokens as u64)
174
+ }
175
+
176
+ /// 本地计算请求的输入 tokens
177
+ fn count_all_tokens_local(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
178
  let mut total = 0;
179
 
180
  // 系统消息
src/anthropic/types.rs CHANGED
@@ -98,7 +98,7 @@ pub struct MessagesRequest {
98
  }
99
 
100
  /// 消息
101
- #[derive(Debug, Deserialize)]
102
  pub struct Message {
103
  pub role: String,
104
  /// 可以是 string 或 ContentBlock 数组
@@ -106,13 +106,13 @@ pub struct Message {
106
  }
107
 
108
  /// 系统消息
109
- #[derive(Debug, Deserialize)]
110
  pub struct SystemMessage {
111
- pub text: String
112
  }
113
 
114
  /// 工具定义
115
- #[derive(Debug, Deserialize)]
116
  pub struct Tool {
117
  pub name: String,
118
  pub description: String,
@@ -154,16 +154,18 @@ pub struct ImageSource {
154
  // === Count Tokens 端点类型 ===
155
 
156
  /// Token 计数请求
157
- #[derive(Debug, Deserialize)]
158
  pub struct CountTokensRequest {
159
  pub model: String,
160
  pub messages: Vec<Message>,
 
161
  pub system: Option<Vec<SystemMessage>>,
 
162
  pub tools: Option<Vec<Tool>>,
163
  }
164
 
165
  /// Token 计数响应
166
- #[derive(Debug, Serialize)]
167
  pub struct CountTokensResponse {
168
  pub input_tokens: i32,
169
  }
 
98
  }
99
 
100
  /// 消息
101
+ #[derive(Debug, Clone, Deserialize, Serialize)]
102
  pub struct Message {
103
  pub role: String,
104
  /// 可以是 string 或 ContentBlock 数组
 
106
  }
107
 
108
  /// 系统消息
109
+ #[derive(Debug, Clone, Deserialize, Serialize)]
110
  pub struct SystemMessage {
111
+ pub text: String,
112
  }
113
 
114
  /// 工具定义
115
+ #[derive(Debug, Clone, Deserialize, Serialize)]
116
  pub struct Tool {
117
  pub name: String,
118
  pub description: String,
 
154
  // === Count Tokens 端点类型 ===
155
 
156
  /// Token 计数请求
157
+ #[derive(Debug, Serialize, Deserialize)]
158
  pub struct CountTokensRequest {
159
  pub model: String,
160
  pub messages: Vec<Message>,
161
+ #[serde(skip_serializing_if = "Option::is_none")]
162
  pub system: Option<Vec<SystemMessage>>,
163
+ #[serde(skip_serializing_if = "Option::is_none")]
164
  pub tools: Option<Vec<Tool>>,
165
  }
166
 
167
  /// Token 计数响应
168
+ #[derive(Debug, Serialize, Deserialize)]
169
  pub struct CountTokensResponse {
170
  pub input_tokens: i32,
171
  }
src/main.rs CHANGED
@@ -48,6 +48,13 @@ async fn main() {
48
  let token_manager = TokenManager::new(config.clone(), credentials.clone());
49
  let kiro_provider = KiroProvider::new(token_manager);
50
 
 
 
 
 
 
 
 
51
  // 构建路由(从凭据获取 profile_arn)
52
  let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
53
 
 
48
  let token_manager = TokenManager::new(config.clone(), credentials.clone());
49
  let kiro_provider = KiroProvider::new(token_manager);
50
 
51
+ // 初始化 count_tokens 配置
52
+ anthropic::token::init_config(anthropic::token::CountTokensConfig {
53
+ api_url: config.count_tokens_api_url.clone(),
54
+ api_key: config.count_tokens_api_key.clone(),
55
+ auth_type: config.count_tokens_auth_type.clone(),
56
+ });
57
+
58
  // 构建路由(从凭据获取 profile_arn)
59
  let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
60