hank9999 commited on
Commit ·
256ba56
1
Parent(s): d40d682
feat: 添加远程 count_tokens API 支持并完善相关配置逻辑
Browse files- src/anthropic/token.rs +104 -1
- src/anthropic/types.rs +8 -6
- src/main.rs +7 -0
src/anthropic/token.rs
CHANGED
|
@@ -7,7 +7,34 @@
|
|
| 7 |
//! - 西文字符:每个计 1 个字符单位
|
| 8 |
//! - 4 个字符单位 = 1 token(四舍五入)
|
| 9 |
|
| 10 |
-
use
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
/// 判断字符是否为非西文字符
|
|
@@ -71,7 +98,83 @@ pub fn count_tokens(text: &str) -> u64 {
|
|
| 71 |
|
| 72 |
|
| 73 |
/// 估算请求的输入 tokens
|
|
|
|
|
|
|
| 74 |
pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
let mut total = 0;
|
| 76 |
|
| 77 |
// 系统消息
|
|
|
|
| 7 |
//! - 西文字符:每个计 1 个字符单位
|
| 8 |
//! - 4 个字符单位 = 1 token(四舍五入)
|
| 9 |
|
| 10 |
+
use std::sync::OnceLock;
|
| 11 |
+
use crate::anthropic::types::{CountTokensRequest, CountTokensResponse, Message, SystemMessage, Tool};
|
| 12 |
+
|
| 13 |
+
/// Count Tokens API 配置
|
| 14 |
+
#[derive(Clone, Default)]
|
| 15 |
+
pub struct CountTokensConfig {
|
| 16 |
+
/// 外部 count_tokens API 地址
|
| 17 |
+
pub api_url: Option<String>,
|
| 18 |
+
/// count_tokens API 密钥
|
| 19 |
+
pub api_key: Option<String>,
|
| 20 |
+
/// count_tokens API 认证类型("x-api-key" 或 "bearer")
|
| 21 |
+
pub auth_type: String,
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
/// 全局配置存储
|
| 25 |
+
static COUNT_TOKENS_CONFIG: OnceLock<CountTokensConfig> = OnceLock::new();
|
| 26 |
+
|
| 27 |
+
/// 初始化 count_tokens 配置
|
| 28 |
+
///
|
| 29 |
+
/// 应在应用启动时调用一次
|
| 30 |
+
pub fn init_config(config: CountTokensConfig) {
|
| 31 |
+
let _ = COUNT_TOKENS_CONFIG.set(config);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
/// 获取配置
|
| 35 |
+
fn get_config() -> Option<&'static CountTokensConfig> {
|
| 36 |
+
COUNT_TOKENS_CONFIG.get()
|
| 37 |
+
}
|
| 38 |
|
| 39 |
|
| 40 |
/// 判断字符是否为非西文字符
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
/// 估算请求的输入 tokens
|
| 101 |
+
///
|
| 102 |
+
/// 优先调用远程 API,失败时回退到本地计算
|
| 103 |
pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
|
| 104 |
+
// 检查是否配置了远程 API
|
| 105 |
+
if let Some(config) = get_config() {
|
| 106 |
+
if let Some(api_url) = &config.api_url {
|
| 107 |
+
// 尝试调用远程 API
|
| 108 |
+
let result = tokio::task::block_in_place(|| {
|
| 109 |
+
tokio::runtime::Handle::current().block_on(
|
| 110 |
+
call_remote_count_tokens(api_url, config, &system, &messages, &tools)
|
| 111 |
+
)
|
| 112 |
+
});
|
| 113 |
+
|
| 114 |
+
match result {
|
| 115 |
+
Ok(tokens) => {
|
| 116 |
+
tracing::debug!("远程 count_tokens API 返回: {}", tokens);
|
| 117 |
+
return tokens;
|
| 118 |
+
}
|
| 119 |
+
Err(e) => {
|
| 120 |
+
tracing::warn!("远程 count_tokens API 调用失败,回退到本地计算: {}", e);
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// 本地计算
|
| 127 |
+
count_all_tokens_local(system, messages, tools)
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// 调用远程 count_tokens API
|
| 131 |
+
async fn call_remote_count_tokens(
|
| 132 |
+
api_url: &str,
|
| 133 |
+
config: &CountTokensConfig,
|
| 134 |
+
system: &Option<Vec<SystemMessage>>,
|
| 135 |
+
messages: &Vec<Message>,
|
| 136 |
+
tools: &Option<Vec<Tool>>,
|
| 137 |
+
) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
|
| 138 |
+
let client = reqwest::Client::new();
|
| 139 |
+
|
| 140 |
+
// 构建请求体
|
| 141 |
+
let request = CountTokensRequest {
|
| 142 |
+
model: "claude-sonnet-4-5-20250929".to_string(), // 模型名称用于 token 计算
|
| 143 |
+
messages: messages.clone(),
|
| 144 |
+
system: system.clone(),
|
| 145 |
+
tools: tools.clone(),
|
| 146 |
+
};
|
| 147 |
+
|
| 148 |
+
// 构建请求
|
| 149 |
+
let mut req_builder = client.post(api_url);
|
| 150 |
+
|
| 151 |
+
// 设置认证头
|
| 152 |
+
if let Some(api_key) = &config.api_key {
|
| 153 |
+
if config.auth_type == "bearer" {
|
| 154 |
+
req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
|
| 155 |
+
} else {
|
| 156 |
+
req_builder = req_builder.header("x-api-key", api_key);
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
// 发送请求
|
| 161 |
+
let response = req_builder
|
| 162 |
+
.header("Content-Type", "application/json")
|
| 163 |
+
.json(&request)
|
| 164 |
+
.timeout(std::time::Duration::from_secs(5))
|
| 165 |
+
.send()
|
| 166 |
+
.await?;
|
| 167 |
+
|
| 168 |
+
if !response.status().is_success() {
|
| 169 |
+
return Err(format!("API 返回错误状态: {}", response.status()).into());
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
let result: CountTokensResponse = response.json().await?;
|
| 173 |
+
Ok(result.input_tokens as u64)
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/// 本地计算请求的输入 tokens
|
| 177 |
+
fn count_all_tokens_local(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
|
| 178 |
let mut total = 0;
|
| 179 |
|
| 180 |
// 系统消息
|
src/anthropic/types.rs
CHANGED
|
@@ -98,7 +98,7 @@ pub struct MessagesRequest {
|
|
| 98 |
}
|
| 99 |
|
| 100 |
/// 消息
|
| 101 |
-
#[derive(Debug, Deserialize)]
|
| 102 |
pub struct Message {
|
| 103 |
pub role: String,
|
| 104 |
/// 可以是 string 或 ContentBlock 数组
|
|
@@ -106,13 +106,13 @@ pub struct Message {
|
|
| 106 |
}
|
| 107 |
|
| 108 |
/// 系统消息
|
| 109 |
-
#[derive(Debug, Deserialize)]
|
| 110 |
pub struct SystemMessage {
|
| 111 |
-
pub text: String
|
| 112 |
}
|
| 113 |
|
| 114 |
/// 工具定义
|
| 115 |
-
#[derive(Debug, Deserialize)]
|
| 116 |
pub struct Tool {
|
| 117 |
pub name: String,
|
| 118 |
pub description: String,
|
|
@@ -154,16 +154,18 @@ pub struct ImageSource {
|
|
| 154 |
// === Count Tokens 端点类型 ===
|
| 155 |
|
| 156 |
/// Token 计数请求
|
| 157 |
-
#[derive(Debug, Deserialize)]
|
| 158 |
pub struct CountTokensRequest {
|
| 159 |
pub model: String,
|
| 160 |
pub messages: Vec<Message>,
|
|
|
|
| 161 |
pub system: Option<Vec<SystemMessage>>,
|
|
|
|
| 162 |
pub tools: Option<Vec<Tool>>,
|
| 163 |
}
|
| 164 |
|
| 165 |
/// Token 计数响应
|
| 166 |
-
#[derive(Debug, Serialize)]
|
| 167 |
pub struct CountTokensResponse {
|
| 168 |
pub input_tokens: i32,
|
| 169 |
}
|
|
|
|
| 98 |
}
|
| 99 |
|
| 100 |
/// 消息
|
| 101 |
+
#[derive(Debug, Clone, Deserialize, Serialize)]
|
| 102 |
pub struct Message {
|
| 103 |
pub role: String,
|
| 104 |
/// 可以是 string 或 ContentBlock 数组
|
|
|
|
| 106 |
}
|
| 107 |
|
| 108 |
/// 系统消息
|
| 109 |
+
#[derive(Debug, Clone, Deserialize, Serialize)]
|
| 110 |
pub struct SystemMessage {
|
| 111 |
+
pub text: String,
|
| 112 |
}
|
| 113 |
|
| 114 |
/// 工具定义
|
| 115 |
+
#[derive(Debug, Clone, Deserialize, Serialize)]
|
| 116 |
pub struct Tool {
|
| 117 |
pub name: String,
|
| 118 |
pub description: String,
|
|
|
|
| 154 |
// === Count Tokens 端点类型 ===
|
| 155 |
|
| 156 |
/// Token 计数请求
|
| 157 |
+
#[derive(Debug, Serialize, Deserialize)]
|
| 158 |
pub struct CountTokensRequest {
|
| 159 |
pub model: String,
|
| 160 |
pub messages: Vec<Message>,
|
| 161 |
+
#[serde(skip_serializing_if = "Option::is_none")]
|
| 162 |
pub system: Option<Vec<SystemMessage>>,
|
| 163 |
+
#[serde(skip_serializing_if = "Option::is_none")]
|
| 164 |
pub tools: Option<Vec<Tool>>,
|
| 165 |
}
|
| 166 |
|
| 167 |
/// Token 计数响应
|
| 168 |
+
#[derive(Debug, Serialize, Deserialize)]
|
| 169 |
pub struct CountTokensResponse {
|
| 170 |
pub input_tokens: i32,
|
| 171 |
}
|
src/main.rs
CHANGED
|
@@ -48,6 +48,13 @@ async fn main() {
|
|
| 48 |
let token_manager = TokenManager::new(config.clone(), credentials.clone());
|
| 49 |
let kiro_provider = KiroProvider::new(token_manager);
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
// 构建路由(从凭据获取 profile_arn)
|
| 52 |
let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
|
| 53 |
|
|
|
|
| 48 |
let token_manager = TokenManager::new(config.clone(), credentials.clone());
|
| 49 |
let kiro_provider = KiroProvider::new(token_manager);
|
| 50 |
|
| 51 |
+
// 初始化 count_tokens 配置
|
| 52 |
+
anthropic::token::init_config(anthropic::token::CountTokensConfig {
|
| 53 |
+
api_url: config.count_tokens_api_url.clone(),
|
| 54 |
+
api_key: config.count_tokens_api_key.clone(),
|
| 55 |
+
auth_type: config.count_tokens_auth_type.clone(),
|
| 56 |
+
});
|
| 57 |
+
|
| 58 |
// 构建路由(从凭据获取 profile_arn)
|
| 59 |
let app = anthropic::create_router_with_provider(&api_key, Some(kiro_provider), credentials.profile_arn.clone());
|
| 60 |
|