hank9999 committed on
Commit
aa2a16e
·
1 Parent(s): b8f39ea

refactor: 由于模型事件中实现了精准上下文统计, 移除 count_tokens 强制 CC 使用会话测试

Browse files
README.md CHANGED
@@ -17,7 +17,6 @@
17
  |------|------|-------------|
18
  | `/v1/models` | GET | 获取可用模型列表 |
19
  | `/v1/messages` | POST | 创建消息(对话) |
20
- | `/v1/messages/count_tokens` | POST | 估算 Token 数量 |
21
 
22
  ## 快速开始
23
 
@@ -40,10 +39,7 @@ cargo build --release
40
  "kiroVersion": "0.8.0",
41
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里", // 不是标准格式会自动忽略, 自动生成
42
  "systemVersion": "darwin#24.6.0",
43
- "nodeVersion": "22.21.1",
44
- "countTokensApiUrl": "https://api.example.com/v1/messages/count_tokens", // 可选,外部 count_tokens API 地址
45
- "countTokensApiKey": "sk-your-count-tokens-api-key", // 可选,外部 API 密钥
46
- "countTokensAuthType": "x-api-key" // 可选,认证类型:x-api-key 或 bearer
47
  }
48
  ```
49
 
@@ -103,9 +99,6 @@ curl http://127.0.0.1:8990/v1/messages \
103
  | `machineId` | string | - | 自定义机器码(64位十六进制)不定义则自动生成 |
104
  | `systemVersion` | string | 随机 | 系统版本标识 |
105
  | `nodeVersion` | string | `22.21.1` | Node.js 版本标识 |
106
- | `countTokensApiUrl` | string | - | 外部 count_tokens API 地址(可选) |
107
- | `countTokensApiKey` | string | - | 外部 count_tokens API 密钥(可选) |
108
- | `countTokensAuthType` | string | `x-api-key` | 外部 API 认证类型:`x-api-key` 或 `bearer` |
109
 
110
  ### credentials.json
111
 
 
17
  |------|------|-------------|
18
  | `/v1/models` | GET | 获取可用模型列表 |
19
  | `/v1/messages` | POST | 创建消息(对话) |
 
20
 
21
  ## 快速开始
22
 
 
39
  "kiroVersion": "0.8.0",
40
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里", // 不是标准格式会自动忽略, 自动生成
41
  "systemVersion": "darwin#24.6.0",
42
+ "nodeVersion": "22.21.1"
 
 
 
43
  }
44
  ```
45
 
 
99
  | `machineId` | string | - | 自定义机器码(64位十六进制)不定义则自动生成 |
100
  | `systemVersion` | string | 随机 | 系统版本标识 |
101
  | `nodeVersion` | string | `22.21.1` | Node.js 版本标识 |
 
 
 
102
 
103
  ### credentials.json
104
 
config.example.json CHANGED
@@ -6,8 +6,5 @@
6
  "kiroVersion": "0.8.0",
7
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里",
8
  "systemVersion": "darwin#24.6.0",
9
- "nodeVersion": "22.21.1",
10
- "countTokensApiUrl": "https://api.example.com/v1/messages/count_tokens",
11
- "countTokensApiKey": "sk-your-count-tokens-api-key",
12
- "countTokensAuthType": "x-api-key"
13
  }
 
6
  "kiroVersion": "0.8.0",
7
  "machineId": "如果你需要自定义机器码请将64位机器码填到这里",
8
  "systemVersion": "darwin#24.6.0",
9
+ "nodeVersion": "22.21.1"
 
 
 
10
  }
src/anthropic/handlers.rs CHANGED
@@ -23,9 +23,7 @@ use crate::kiro::parser::decoder::EventStreamDecoder;
23
  use super::converter::{convert_request, ConversionError};
24
  use super::middleware::AppState;
25
  use super::stream::{SseEvent, StreamContext};
26
- use super::types::{
27
- CountTokensRequest, CountTokensResponse, ErrorResponse, MessagesRequest, Model, ModelsResponse,
28
- };
29
 
30
  /// GET /v1/models
31
  ///
@@ -470,19 +468,3 @@ async fn handle_non_stream_request(
470
 
471
 
472
 
473
- /// POST /v1/messages/count_tokens
474
- ///
475
- /// 计算消息的 token 数量
476
- pub async fn count_tokens(JsonExtractor(payload): JsonExtractor<CountTokensRequest>) -> impl IntoResponse {
477
- tracing::info!(
478
- model = %payload.model,
479
- message_count = %payload.messages.len(),
480
- "Received POST /v1/messages/count_tokens request"
481
- );
482
-
483
- let total_tokens = token::count_all_tokens(payload.model, payload.system, payload.messages, payload.tools) as i32;
484
-
485
- Json(CountTokensResponse {
486
- input_tokens: total_tokens.max(1) as i32,
487
- })
488
- }
 
23
  use super::converter::{convert_request, ConversionError};
24
  use super::middleware::AppState;
25
  use super::stream::{SseEvent, StreamContext};
26
+ use super::types::{ErrorResponse, MessagesRequest, Model, ModelsResponse};
 
 
27
 
28
  /// GET /v1/models
29
  ///
 
468
 
469
 
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/anthropic/mod.rs CHANGED
@@ -5,7 +5,6 @@
5
  //! # 支持的端点
6
  //! - `GET /v1/models` - 获取可用模型列表
7
  //! - `POST /v1/messages` - 创建消息(对话)
8
- //! - `POST /v1/messages/count_tokens` - 计算 token 数量
9
  //!
10
  //! # 使用示例
11
  //! ```rust,ignore
 
5
  //! # 支持的端点
6
  //! - `GET /v1/models` - 获取可用模型列表
7
  //! - `POST /v1/messages` - 创建消息(对话)
 
8
  //!
9
  //! # 使用示例
10
  //! ```rust,ignore
src/anthropic/router.rs CHANGED
@@ -9,7 +9,7 @@ use axum::{
9
  use crate::kiro::provider::KiroProvider;
10
 
11
  use super::{
12
- handlers::{count_tokens, get_models, post_messages},
13
  middleware::{auth_middleware, cors_layer, AppState},
14
  };
15
 
@@ -18,7 +18,6 @@ use super::{
18
  /// # 端点
19
  /// - `GET /v1/models` - 获取可用模型列表
20
  /// - `POST /v1/messages` - 创建消息(对话)
21
- /// - `POST /v1/messages/count_tokens` - 计算 token 数量
22
  ///
23
  /// # 认证
24
  /// 所有 `/v1` 路径需要 API Key 认证,支持:
@@ -47,7 +46,6 @@ pub fn create_router_with_provider(
47
  let v1_routes = Router::new()
48
  .route("/models", get(get_models))
49
  .route("/messages", post(post_messages))
50
- .route("/messages/count_tokens", post(count_tokens))
51
  .layer(middleware::from_fn_with_state(
52
  state.clone(),
53
  auth_middleware,
 
9
  use crate::kiro::provider::KiroProvider;
10
 
11
  use super::{
12
+ handlers::{get_models, post_messages},
13
  middleware::{auth_middleware, cors_layer, AppState},
14
  };
15
 
 
18
  /// # 端点
19
  /// - `GET /v1/models` - 获取可用模型列表
20
  /// - `POST /v1/messages` - 创建消息(对话)
 
21
  ///
22
  /// # 认证
23
  /// 所有 `/v1` 路径需要 API Key 认证,支持:
 
46
  let v1_routes = Router::new()
47
  .route("/models", get(get_models))
48
  .route("/messages", post(post_messages))
 
49
  .layer(middleware::from_fn_with_state(
50
  state.clone(),
51
  auth_middleware,
src/anthropic/token.rs CHANGED
@@ -7,34 +7,7 @@
7
  //! - 西文字符:每个计 1 个字符单位
8
  //! - 4 个字符单位 = 1 token(四舍五入)
9
 
10
- use std::sync::OnceLock;
11
- use crate::anthropic::types::{CountTokensRequest, CountTokensResponse, Message, SystemMessage, Tool};
12
-
13
- /// Count Tokens API 配置
14
- #[derive(Clone, Default)]
15
- pub struct CountTokensConfig {
16
- /// 外部 count_tokens API 地址
17
- pub api_url: Option<String>,
18
- /// count_tokens API 密钥
19
- pub api_key: Option<String>,
20
- /// count_tokens API 认证类型("x-api-key" 或 "bearer")
21
- pub auth_type: String,
22
- }
23
-
24
- /// 全局配置存储
25
- static COUNT_TOKENS_CONFIG: OnceLock<CountTokensConfig> = OnceLock::new();
26
-
27
- /// 初始化 count_tokens 配置
28
- ///
29
- /// 应在应用启动时调用一次
30
- pub fn init_config(config: CountTokensConfig) {
31
- let _ = COUNT_TOKENS_CONFIG.set(config);
32
- }
33
-
34
- /// 获取配置
35
- fn get_config() -> Option<&'static CountTokensConfig> {
36
- COUNT_TOKENS_CONFIG.get()
37
- }
38
 
39
 
40
  /// 判断字符是否为非西文字符
@@ -98,84 +71,7 @@ pub fn count_tokens(text: &str) -> u64 {
98
 
99
 
100
  /// 估算请求的输入 tokens
101
- ///
102
- /// 优先调用远程 API,失败时回退到本地计算
103
- pub(crate) fn count_all_tokens(model: String, system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
104
- // 检查是否配置了远程 API
105
- if let Some(config) = get_config() {
106
- if let Some(api_url) = &config.api_url {
107
- // 尝试调用远程 API
108
- let result = tokio::task::block_in_place(|| {
109
- tokio::runtime::Handle::current().block_on(
110
- call_remote_count_tokens(api_url, config, model, &system, &messages, &tools)
111
- )
112
- });
113
-
114
- match result {
115
- Ok(tokens) => {
116
- tracing::debug!("远程 count_tokens API 返回: {}", tokens);
117
- return tokens;
118
- }
119
- Err(e) => {
120
- tracing::warn!("远程 count_tokens API 调用失败,回退到本地计算: {}", e);
121
- }
122
- }
123
- }
124
- }
125
-
126
- // 本地计算
127
- count_all_tokens_local(system, messages, tools)
128
- }
129
-
130
- /// 调用远程 count_tokens API
131
- async fn call_remote_count_tokens(
132
- api_url: &str,
133
- config: &CountTokensConfig,
134
- model: String,
135
- system: &Option<Vec<SystemMessage>>,
136
- messages: &Vec<Message>,
137
- tools: &Option<Vec<Tool>>,
138
- ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
139
- let client = reqwest::Client::new();
140
-
141
- // 构建请求体
142
- let request = CountTokensRequest {
143
- model: model, // 模型名称用于 token 计算
144
- messages: messages.clone(),
145
- system: system.clone(),
146
- tools: tools.clone(),
147
- };
148
-
149
- // 构建请求
150
- let mut req_builder = client.post(api_url);
151
-
152
- // 设置认证头
153
- if let Some(api_key) = &config.api_key {
154
- if config.auth_type == "bearer" {
155
- req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
156
- } else {
157
- req_builder = req_builder.header("x-api-key", api_key);
158
- }
159
- }
160
-
161
- // 发送请求
162
- let response = req_builder
163
- .header("Content-Type", "application/json")
164
- .json(&request)
165
- .timeout(std::time::Duration::from_secs(5))
166
- .send()
167
- .await?;
168
-
169
- if !response.status().is_success() {
170
- return Err(format!("API 返回错误状态: {}", response.status()).into());
171
- }
172
-
173
- let result: CountTokensResponse = response.json().await?;
174
- Ok(result.input_tokens as u64)
175
- }
176
-
177
- /// 本地计算请求的输入 tokens
178
- fn count_all_tokens_local(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
179
  let mut total = 0;
180
 
181
  // 系统消息
 
7
  //! - 西文字符:每个计 1 个字符单位
8
  //! - 4 个字符单位 = 1 token(四舍五入)
9
 
10
+ use crate::anthropic::types::{Message, SystemMessage, Tool};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  /// 判断字符是否为非西文字符
 
71
 
72
 
73
  /// 估算请求的输入 tokens
74
+ pub(crate) fn count_all_tokens(_model: String, system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  let mut total = 0;
76
 
77
  // 系统消息
src/anthropic/types.rs CHANGED
@@ -150,22 +150,3 @@ pub struct ImageSource {
150
  pub media_type: String,
151
  pub data: String,
152
  }
153
-
154
- // === Count Tokens 端点类型 ===
155
-
156
- /// Token 计数请求
157
- #[derive(Debug, Serialize, Deserialize)]
158
- pub struct CountTokensRequest {
159
- pub model: String,
160
- pub messages: Vec<Message>,
161
- #[serde(skip_serializing_if = "Option::is_none")]
162
- pub system: Option<Vec<SystemMessage>>,
163
- #[serde(skip_serializing_if = "Option::is_none")]
164
- pub tools: Option<Vec<Tool>>,
165
- }
166
-
167
- /// Token 计数响应
168
- #[derive(Debug, Serialize, Deserialize)]
169
- pub struct CountTokensResponse {
170
- pub input_tokens: i32,
171
- }
 
150
  pub media_type: String,
151
  pub data: String,
152
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/main.rs CHANGED
@@ -48,13 +48,6 @@ async fn main() {
48
  let token_manager = TokenManager::new(config.clone(), credentials.clone());
49
  let kiro_provider = KiroProvider::new(token_manager);
50
 
51
- // 初始化 count_tokens 配置
52
- anthropic::token::init_config(anthropic::token::CountTokensConfig {
53
- api_url: config.count_tokens_api_url.clone(),
54
- api_key: config.count_tokens_api_key.clone(),
55
- auth_type: config.count_tokens_auth_type.clone(),
56
- });
57
-
58
  // 构建路由(从凭据获取 profile_arn)
59
  let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
60
 
@@ -65,7 +58,6 @@ async fn main() {
65
  tracing::info!("可用 API:");
66
  tracing::info!(" GET /v1/models");
67
  tracing::info!(" POST /v1/messages");
68
- tracing::info!(" POST /v1/messages/count_tokens");
69
 
70
  let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
71
  axum::serve(listener, app).await.unwrap();
 
48
  let token_manager = TokenManager::new(config.clone(), credentials.clone());
49
  let kiro_provider = KiroProvider::new(token_manager);
50
 
 
 
 
 
 
 
 
51
  // 构建路由(从凭据获取 profile_arn)
52
  let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
53
 
 
58
  tracing::info!("可用 API:");
59
  tracing::info!(" GET /v1/models");
60
  tracing::info!(" POST /v1/messages");
 
61
 
62
  let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
63
  axum::serve(listener, app).await.unwrap();
src/model/config.rs CHANGED
@@ -29,18 +29,6 @@ pub struct Config {
29
 
30
  #[serde(default = "default_node_version")]
31
  pub node_version: String,
32
-
33
- /// 外部 count_tokens API 地址(可选)
34
- #[serde(default)]
35
- pub count_tokens_api_url: Option<String>,
36
-
37
- /// count_tokens API 密钥(可选)
38
- #[serde(default)]
39
- pub count_tokens_api_key: Option<String>,
40
-
41
- /// count_tokens API 认证类型(可选,"x-api-key" 或 "bearer",默认 "x-api-key")
42
- #[serde(default = "default_count_tokens_auth_type")]
43
- pub count_tokens_auth_type: String,
44
  }
45
 
46
  fn default_host() -> String {
@@ -68,10 +56,6 @@ fn default_node_version() -> String {
68
  "22.21.1".to_string()
69
  }
70
 
71
- fn default_count_tokens_auth_type() -> String {
72
- "x-api-key".to_string()
73
- }
74
-
75
  impl Default for Config {
76
  fn default() -> Self {
77
  Self {
@@ -83,9 +67,6 @@ impl Default for Config {
83
  api_key: None,
84
  system_version: default_system_version(),
85
  node_version: default_node_version(),
86
- count_tokens_api_url: None,
87
- count_tokens_api_key: None,
88
- count_tokens_auth_type: default_count_tokens_auth_type(),
89
  }
90
  }
91
  }
 
29
 
30
  #[serde(default = "default_node_version")]
31
  pub node_version: String,
 
 
 
 
 
 
 
 
 
 
 
 
32
  }
33
 
34
  fn default_host() -> String {
 
56
  "22.21.1".to_string()
57
  }
58
 
 
 
 
 
59
  impl Default for Config {
60
  fn default() -> Self {
61
  Self {
 
67
  api_key: None,
68
  system_version: default_system_version(),
69
  node_version: default_node_version(),
 
 
 
70
  }
71
  }
72
  }