hank9999 commited on
Commit ·
f0d4864
1
Parent(s): 256ba56
fix: 为 count_all_tokens 增加 model 参数,支持动态指定模型
Browse files- src/anthropic/handlers.rs +2 -2
- src/anthropic/token.rs +4 -3
src/anthropic/handlers.rs
CHANGED
|
@@ -143,7 +143,7 @@ pub async fn post_messages(
|
|
| 143 |
tracing::debug!("Kiro request body: {}", request_body);
|
| 144 |
|
| 145 |
// 估算输入 tokens
|
| 146 |
-
let input_tokens = token::count_all_tokens(payload.system, payload.messages, payload.tools) as i32;
|
| 147 |
|
| 148 |
// 检查是否启用了thinking
|
| 149 |
let thinking_enabled = payload.thinking
|
|
@@ -440,7 +440,7 @@ pub async fn count_tokens(JsonExtractor(payload): JsonExtractor<CountTokensReque
|
|
| 440 |
"Received POST /v1/messages/count_tokens request"
|
| 441 |
);
|
| 442 |
|
| 443 |
-
let total_tokens = token::count_all_tokens(payload.system, payload.messages, payload.tools) as i32;
|
| 444 |
|
| 445 |
Json(CountTokensResponse {
|
| 446 |
input_tokens: total_tokens.max(1) as i32,
|
|
|
|
| 143 |
tracing::debug!("Kiro request body: {}", request_body);
|
| 144 |
|
| 145 |
// 估算输入 tokens
|
| 146 |
+
let input_tokens = token::count_all_tokens(payload.model.clone(), payload.system, payload.messages, payload.tools) as i32;
|
| 147 |
|
| 148 |
// 检查是否启用了thinking
|
| 149 |
let thinking_enabled = payload.thinking
|
|
|
|
| 440 |
"Received POST /v1/messages/count_tokens request"
|
| 441 |
);
|
| 442 |
|
| 443 |
+
let total_tokens = token::count_all_tokens(payload.model, payload.system, payload.messages, payload.tools) as i32;
|
| 444 |
|
| 445 |
Json(CountTokensResponse {
|
| 446 |
input_tokens: total_tokens.max(1) as i32,
|
src/anthropic/token.rs
CHANGED
|
@@ -100,14 +100,14 @@ pub fn count_tokens(text: &str) -> u64 {
|
|
| 100 |
/// 估算请求的输入 tokens
|
| 101 |
///
|
| 102 |
/// 优先调用远程 API,失败时回退到本地计算
|
| 103 |
-
pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
|
| 104 |
// 检查是否配置了远程 API
|
| 105 |
if let Some(config) = get_config() {
|
| 106 |
if let Some(api_url) = &config.api_url {
|
| 107 |
// 尝试调用远程 API
|
| 108 |
let result = tokio::task::block_in_place(|| {
|
| 109 |
tokio::runtime::Handle::current().block_on(
|
| 110 |
-
call_remote_count_tokens(api_url, config, &system, &messages, &tools)
|
| 111 |
)
|
| 112 |
});
|
| 113 |
|
|
@@ -131,6 +131,7 @@ pub(crate) fn count_all_tokens(system: Option<Vec<SystemMessage>>, messages: Vec
|
|
| 131 |
async fn call_remote_count_tokens(
|
| 132 |
api_url: &str,
|
| 133 |
config: &CountTokensConfig,
|
|
|
|
| 134 |
system: &Option<Vec<SystemMessage>>,
|
| 135 |
messages: &Vec<Message>,
|
| 136 |
tools: &Option<Vec<Tool>>,
|
|
@@ -139,7 +140,7 @@ async fn call_remote_count_tokens(
|
|
| 139 |
|
| 140 |
// 构建请求体
|
| 141 |
let request = CountTokensRequest {
|
| 142 |
-
model:
|
| 143 |
messages: messages.clone(),
|
| 144 |
system: system.clone(),
|
| 145 |
tools: tools.clone(),
|
|
|
|
| 100 |
/// 估算请求的输入 tokens
|
| 101 |
///
|
| 102 |
/// 优先调用远程 API,失败时回退到本地计算
|
| 103 |
+
pub(crate) fn count_all_tokens(model: String, system: Option<Vec<SystemMessage>>, messages: Vec<Message>, tools: Option<Vec<Tool>>) -> u64 {
|
| 104 |
// 检查是否配置了远程 API
|
| 105 |
if let Some(config) = get_config() {
|
| 106 |
if let Some(api_url) = &config.api_url {
|
| 107 |
// 尝试调用远程 API
|
| 108 |
let result = tokio::task::block_in_place(|| {
|
| 109 |
tokio::runtime::Handle::current().block_on(
|
| 110 |
+
call_remote_count_tokens(api_url, config, model, &system, &messages, &tools)
|
| 111 |
)
|
| 112 |
});
|
| 113 |
|
|
|
|
| 131 |
async fn call_remote_count_tokens(
|
| 132 |
api_url: &str,
|
| 133 |
config: &CountTokensConfig,
|
| 134 |
+
model: String,
|
| 135 |
system: &Option<Vec<SystemMessage>>,
|
| 136 |
messages: &Vec<Message>,
|
| 137 |
tools: &Option<Vec<Tool>>,
|
|
|
|
| 140 |
|
| 141 |
// 构建请求体
|
| 142 |
let request = CountTokensRequest {
|
| 143 |
+
model: model, // 模型名称用于 token 计算
|
| 144 |
messages: messages.clone(),
|
| 145 |
system: system.clone(),
|
| 146 |
tools: tools.clone(),
|