hank9999 committed on
Commit
2b792a7
·
1 Parent(s): bf70739

Revert "refactor: 由于模型事件中实现了精准上下文统计, 移除 count_tokens 强制 CC 使用会话测试"

Browse files

This reverts commit aa2a16e9c76091a0a648d6fa738e2f667f28e003.

README.md CHANGED
@@ -17,6 +17,7 @@
17
  |------|------|-------------|
18
  | `/v1/models` | GET | 获取可用模型列表 |
19
  | `/v1/messages` | POST | 创建消息(对话) |
 
20
 
21
  ## 快速开始
22
 
@@ -39,7 +40,10 @@ cargo build --release
39
  "kiroVersion": "0.8.0",
40
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里", // 不是标准格式会自动忽略, 自动生成
41
  "systemVersion": "darwin#24.6.0",
42
- "nodeVersion": "22.21.1"
 
 
 
43
  }
44
  ```
45
 
@@ -99,6 +103,9 @@ curl http://127.0.0.1:8990/v1/messages \
99
  | `machineId` | string | - | 自定义机器码(64位十六进制)不定义则自动生成 |
100
  | `systemVersion` | string | 随机 | 系统版本标识 |
101
  | `nodeVersion` | string | `22.21.1` | Node.js 版本标识 |
 
 
 
102
 
103
  ### credentials.json
104
 
 
17
  |------|------|-------------|
18
  | `/v1/models` | GET | 获取可用模型列表 |
19
  | `/v1/messages` | POST | 创建消息(对话) |
20
+ | `/v1/messages/count_tokens` | POST | 估算 Token 数量 |
21
 
22
  ## 快速开始
23
 
 
40
  "kiroVersion": "0.8.0",
41
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里", // 不是标准格式会自动忽略, 自动生成
42
  "systemVersion": "darwin#24.6.0",
43
+ "nodeVersion": "22.21.1",
44
+ "countTokensApiUrl": "https://api.example.com/v1/messages/count_tokens", // 可选,外部 count_tokens API 地址
45
+ "countTokensApiKey": "sk-your-count-tokens-api-key", // 可选,外部 API 密钥
46
+ "countTokensAuthType": "x-api-key" // 可选,认证类型:x-api-key 或 bearer
47
  }
48
  ```
49
 
 
103
  | `machineId` | string | - | 自定义机器码(64位十六进制)不定义则自动生成 |
104
  | `systemVersion` | string | 随机 | 系统版本标识 |
105
  | `nodeVersion` | string | `22.21.1` | Node.js 版本标识 |
106
+ | `countTokensApiUrl` | string | - | 外部 count_tokens API 地址(可选) |
107
+ | `countTokensApiKey` | string | - | 外部 count_tokens API 密钥(可选) |
108
+ | `countTokensAuthType` | string | `x-api-key` | 外部 API 认证类型:`x-api-key` 或 `bearer` |
109
 
110
  ### credentials.json
111
 
config.example.json CHANGED
@@ -6,5 +6,8 @@
6
  "kiroVersion": "0.8.0",
7
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里",
8
  "systemVersion": "darwin#24.6.0",
9
- "nodeVersion": "22.21.1"
 
 
 
10
  }
 
6
  "kiroVersion": "0.8.0",
7
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里",
8
  "systemVersion": "darwin#24.6.0",
9
+ "nodeVersion": "22.21.1",
10
+ "countTokensApiUrl": "https://api.example.com/v1/messages/count_tokens",
11
+ "countTokensApiKey": "sk-your-count-tokens-api-key",
12
+ "countTokensAuthType": "x-api-key"
13
  }
src/anthropic/handlers.rs CHANGED
@@ -23,7 +23,9 @@ use crate::kiro::parser::decoder::EventStreamDecoder;
23
  use super::converter::{convert_request, ConversionError};
24
  use super::middleware::AppState;
25
  use super::stream::{SseEvent, StreamContext};
26
- use super::types::{ErrorResponse, MessagesRequest, Model, ModelsResponse};
 
 
27
 
28
  /// GET /v1/models
29
  ///
@@ -468,3 +470,19 @@ async fn handle_non_stream_request(
468
 
469
 
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  use super::converter::{convert_request, ConversionError};
24
  use super::middleware::AppState;
25
  use super::stream::{SseEvent, StreamContext};
26
+ use super::types::{
27
+ CountTokensRequest, CountTokensResponse, ErrorResponse, MessagesRequest, Model, ModelsResponse,
28
+ };
29
 
30
  /// GET /v1/models
31
  ///
 
470
 
471
 
472
 
473
+ /// POST /v1/messages/count_tokens
474
+ ///
475
+ /// 计算消息的 token 数量
476
+ pub async fn count_tokens(JsonExtractor(payload): JsonExtractor<CountTokensRequest>) -> impl IntoResponse {
477
+ tracing::info!(
478
+ model = %payload.model,
479
+ message_count = %payload.messages.len(),
480
+ "Received POST /v1/messages/count_tokens request"
481
+ );
482
+
483
+ let total_tokens = token::count_all_tokens(payload.model, payload.system, payload.messages, payload.tools) as i32;
484
+
485
+ Json(CountTokensResponse {
486
+ input_tokens: total_tokens.max(1) as i32,
487
+ })
488
+ }
src/anthropic/mod.rs CHANGED
@@ -5,6 +5,7 @@
5
  //! # 支持的端点
6
  //! - `GET /v1/models` - 获取可用模型列表
7
  //! - `POST /v1/messages` - 创建消息(对话)
 
8
  //!
9
  //! # 使用示例
10
  //! ```rust,ignore
 
5
  //! # 支持的端点
6
  //! - `GET /v1/models` - 获取可用模型列表
7
  //! - `POST /v1/messages` - 创建消息(对话)
8
+ //! - `POST /v1/messages/count_tokens` - 计算 token 数量
9
  //!
10
  //! # 使用示例
11
  //! ```rust,ignore
src/anthropic/router.rs CHANGED
@@ -9,7 +9,7 @@ use axum::{
9
  use crate::kiro::provider::KiroProvider;
10
 
11
  use super::{
12
- handlers::{get_models, post_messages},
13
  middleware::{auth_middleware, cors_layer, AppState},
14
  };
15
 
@@ -18,6 +18,7 @@ use super::{
18
  /// # 端点
19
  /// - `GET /v1/models` - 获取可用模型列表
20
  /// - `POST /v1/messages` - 创建消息(对话)
 
21
  ///
22
  /// # 认证
23
  /// 所有 `/v1` 路径需要 API Key 认证,支持:
@@ -46,6 +47,7 @@ pub fn create_router_with_provider(
46
  let v1_routes = Router::new()
47
  .route("/models", get(get_models))
48
  .route("/messages", post(post_messages))
 
49
  .layer(middleware::from_fn_with_state(
50
  state.clone(),
51
  auth_middleware,
 
9
  use crate::kiro::provider::KiroProvider;
10
 
11
  use super::{
12
+ handlers::{count_tokens, get_models, post_messages},
13
  middleware::{auth_middleware, cors_layer, AppState},
14
  };
15
 
 
18
  /// # 端点
19
  /// - `GET /v1/models` - 获取可用模型列表
20
  /// - `POST /v1/messages` - 创建消息(对话)
21
+ /// - `POST /v1/messages/count_tokens` - 计算 token 数量
22
  ///
23
  /// # 认证
24
  /// 所有 `/v1` 路径需要 API Key 认证,支持:
 
47
  let v1_routes = Router::new()
48
  .route("/models", get(get_models))
49
  .route("/messages", post(post_messages))
50
+ .route("/messages/count_tokens", post(count_tokens))
51
  .layer(middleware::from_fn_with_state(
52
  state.clone(),
53
  auth_middleware,
src/anthropic/token.rs CHANGED
@@ -7,7 +7,34 @@
7
  //! - 西文字符:每个计 1 个字符单位
8
  //! - 4 个字符单位 = 1 token(四舍五入)
9
 
10
- use crate::anthropic::types::{Message, SystemMessage, Tool};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  /// 判断字符是否为非西文字符
@@ -71,7 +98,84 @@ pub fn count_tokens(text: &str) -> u64 {
71
 
72
 
73
  /// 估算请求的输入 tokens
74
- pub(crate) fn count_all_tokens(_model: String, system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  let mut total = 0;
76
 
77
  // 系统消息
 
7
  //! - 西文字符:每个计 1 个字符单位
8
  //! - 4 个字符单位 = 1 token(四舍五入)
9
 
10
+ use std::sync::OnceLock;
11
+ use crate::anthropic::types::{CountTokensRequest, CountTokensResponse, Message, SystemMessage, Tool};
12
+
13
+ /// Count Tokens API 配置
14
+ #[derive(Clone, Default)]
15
+ pub struct CountTokensConfig {
16
+ /// 外部 count_tokens API 地址
17
+ pub api_url: Option<String>,
18
+ /// count_tokens API 密钥
19
+ pub api_key: Option<String>,
20
+ /// count_tokens API 认证类型("x-api-key" 或 "bearer")
21
+ pub auth_type: String,
22
+ }
23
+
24
+ /// 全局配置存储
25
+ static COUNT_TOKENS_CONFIG: OnceLock<CountTokensConfig> = OnceLock::new();
26
+
27
+ /// 初始化 count_tokens 配置
28
+ ///
29
+ /// 应在应用启动时调用一次
30
+ pub fn init_config(config: CountTokensConfig) {
31
+ let _ = COUNT_TOKENS_CONFIG.set(config);
32
+ }
33
+
34
+ /// 获取配置
35
+ fn get_config() -> Option<&'static CountTokensConfig> {
36
+ COUNT_TOKENS_CONFIG.get()
37
+ }
38
 
39
 
40
  /// 判断字符是否为非西文字符
 
98
 
99
 
100
  /// 估算请求的输入 tokens
101
+ ///
102
+ /// 优先调用远程 API,失败时回退到本地计算
103
+ pub(crate) fn count_all_tokens(model: String, system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
104
+ // 检查是否配置了远程 API
105
+ if let Some(config) = get_config() {
106
+ if let Some(api_url) = &config.api_url {
107
+ // 尝试调用远程 API
108
+ let result = tokio::task::block_in_place(|| {
109
+ tokio::runtime::Handle::current().block_on(
110
+ call_remote_count_tokens(api_url, config, model, &system, &messages, &tools)
111
+ )
112
+ });
113
+
114
+ match result {
115
+ Ok(tokens) => {
116
+ tracing::debug!("远程 count_tokens API 返回: {}", tokens);
117
+ return tokens;
118
+ }
119
+ Err(e) => {
120
+ tracing::warn!("远程 count_tokens API 调用失败,回退到本地计算: {}", e);
121
+ }
122
+ }
123
+ }
124
+ }
125
+
126
+ // 本地计算
127
+ count_all_tokens_local(system, messages, tools)
128
+ }
129
+
130
+ /// 调用远程 count_tokens API
131
+ async fn call_remote_count_tokens(
132
+ api_url: &str,
133
+ config: &CountTokensConfig,
134
+ model: String,
135
+ system: &Option<Vec<SystemMessage>>,
136
+ messages: &Vec<Message>,
137
+ tools: &Option<Vec<Tool>>,
138
+ ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
139
+ let client = reqwest::Client::new();
140
+
141
+ // 构建请求体
142
+ let request = CountTokensRequest {
143
+ model: model, // 模型名称用于 token 计算
144
+ messages: messages.clone(),
145
+ system: system.clone(),
146
+ tools: tools.clone(),
147
+ };
148
+
149
+ // 构建请求
150
+ let mut req_builder = client.post(api_url);
151
+
152
+ // 设置认证头
153
+ if let Some(api_key) = &config.api_key {
154
+ if config.auth_type == "bearer" {
155
+ req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
156
+ } else {
157
+ req_builder = req_builder.header("x-api-key", api_key);
158
+ }
159
+ }
160
+
161
+ // 发送请求
162
+ let response = req_builder
163
+ .header("Content-Type", "application/json")
164
+ .json(&request)
165
+ .timeout(std::time::Duration::from_secs(5))
166
+ .send()
167
+ .await?;
168
+
169
+ if !response.status().is_success() {
170
+ return Err(format!("API 返回错误状态: {}", response.status()).into());
171
+ }
172
+
173
+ let result: CountTokensResponse = response.json().await?;
174
+ Ok(result.input_tokens as u64)
175
+ }
176
+
177
+ /// 本地计算请求的输入 tokens
178
+ fn count_all_tokens_local(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
179
  let mut total = 0;
180
 
181
  // 系统消息
src/anthropic/types.rs CHANGED
@@ -150,3 +150,22 @@ pub struct ImageSource {
150
  pub media_type: String,
151
  pub data: String,
152
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  pub media_type: String,
151
  pub data: String,
152
  }
153
+
154
+ // === Count Tokens 端点类型 ===
155
+
156
+ /// Token 计数请求
157
+ #[derive(Debug, Serialize, Deserialize)]
158
+ pub struct CountTokensRequest {
159
+ pub model: String,
160
+ pub messages: Vec<Message>,
161
+ #[serde(skip_serializing_if = "Option::is_none")]
162
+ pub system: Option<Vec<SystemMessage>>,
163
+ #[serde(skip_serializing_if = "Option::is_none")]
164
+ pub tools: Option<Vec<Tool>>,
165
+ }
166
+
167
+ /// Token 计数响应
168
+ #[derive(Debug, Serialize, Deserialize)]
169
+ pub struct CountTokensResponse {
170
+ pub input_tokens: i32,
171
+ }
src/main.rs CHANGED
@@ -48,6 +48,13 @@ async fn main() {
48
  let token_manager = TokenManager::new(config.clone(), credentials.clone());
49
  let kiro_provider = KiroProvider::new(token_manager);
50
 
 
 
 
 
 
 
 
51
  // 构建路由(从凭据获取 profile_arn)
52
  let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
53
 
@@ -58,6 +65,7 @@ async fn main() {
58
  tracing::info!("可用 API:");
59
  tracing::info!(" GET /v1/models");
60
  tracing::info!(" POST /v1/messages");
 
61
 
62
  let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
63
  axum::serve(listener, app).await.unwrap();
 
48
  let token_manager = TokenManager::new(config.clone(), credentials.clone());
49
  let kiro_provider = KiroProvider::new(token_manager);
50
 
51
+ // 初始化 count_tokens 配置
52
+ anthropic::token::init_config(anthropic::token::CountTokensConfig {
53
+ api_url: config.count_tokens_api_url.clone(),
54
+ api_key: config.count_tokens_api_key.clone(),
55
+ auth_type: config.count_tokens_auth_type.clone(),
56
+ });
57
+
58
  // 构建路由(从凭据获取 profile_arn)
59
  let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
60
 
 
65
  tracing::info!("可用 API:");
66
  tracing::info!(" GET /v1/models");
67
  tracing::info!(" POST /v1/messages");
68
+ tracing::info!(" POST /v1/messages/count_tokens");
69
 
70
  let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
71
  axum::serve(listener, app).await.unwrap();
src/model/config.rs CHANGED
@@ -29,6 +29,18 @@ pub struct Config {
29
 
30
  #[serde(default = "default_node_version")]
31
  pub node_version: String,
 
 
 
 
 
 
 
 
 
 
 
 
32
  }
33
 
34
  fn default_host() -> String {
@@ -56,6 +68,10 @@ fn default_node_version() -> String {
56
  "22.21.1".to_string()
57
  }
58
 
 
 
 
 
59
  impl Default for Config {
60
  fn default() -> Self {
61
  Self {
@@ -67,6 +83,9 @@ impl Default for Config {
67
  api_key: None,
68
  system_version: default_system_version(),
69
  node_version: default_node_version(),
 
 
 
70
  }
71
  }
72
  }
 
29
 
30
  #[serde(default = "default_node_version")]
31
  pub node_version: String,
32
+
33
+ /// 外部 count_tokens API 地址(可选)
34
+ #[serde(default)]
35
+ pub count_tokens_api_url: Option<String>,
36
+
37
+ /// count_tokens API 密钥(可选)
38
+ #[serde(default)]
39
+ pub count_tokens_api_key: Option<String>,
40
+
41
+ /// count_tokens API 认证类型(可选,"x-api-key" 或 "bearer",默认 "x-api-key")
42
+ #[serde(default = "default_count_tokens_auth_type")]
43
+ pub count_tokens_auth_type: String,
44
  }
45
 
46
  fn default_host() -> String {
 
68
  "22.21.1".to_string()
69
  }
70
 
71
+ fn default_count_tokens_auth_type() -> String {
72
+ "x-api-key".to_string()
73
+ }
74
+
75
  impl Default for Config {
76
  fn default() -> Self {
77
  Self {
 
83
  api_key: None,
84
  system_version: default_system_version(),
85
  node_version: default_node_version(),
86
+ count_tokens_api_url: None,
87
+ count_tokens_api_key: None,
88
+ count_tokens_auth_type: default_count_tokens_auth_type(),
89
  }
90
  }
91
  }