Spaces:
Sleeping
Sleeping
| // 移除冗余的顶层导入,因为这些在代码中已由 full path 或局部导入处理 | |
| use dashmap::DashMap; | |
| use std::path::PathBuf; | |
| use std::sync::atomic::{AtomicUsize, Ordering}; | |
| use std::sync::Arc; | |
| pub struct ProxyToken { | |
| pub account_id: String, | |
| pub access_token: String, | |
| pub refresh_token: String, | |
| pub expires_in: i64, | |
| pub timestamp: i64, | |
| pub email: String, | |
| pub account_path: PathBuf, // 账号文件路径,用于更新 | |
| pub project_id: Option<String>, | |
| } | |
| pub struct TokenManager { | |
| tokens: Arc<DashMap<String, ProxyToken>>, // account_id -> ProxyToken | |
| current_index: Arc<AtomicUsize>, | |
| last_used_account: Arc<tokio::sync::Mutex<Option<(String, std::time::Instant)>>>, | |
| data_dir: PathBuf, | |
| } | |
| impl TokenManager { | |
| /// 创建新的 TokenManager | |
| pub fn new(data_dir: PathBuf) -> Self { | |
| Self { | |
| tokens: Arc::new(DashMap::new()), | |
| current_index: Arc::new(AtomicUsize::new(0)), | |
| last_used_account: Arc::new(tokio::sync::Mutex::new(None)), | |
| data_dir, | |
| } | |
| } | |
| /// 从主应用账号目录加载所有账号 | |
| pub async fn load_accounts(&self) -> Result<usize, String> { | |
| let accounts_dir = self.data_dir.join("accounts"); | |
| if !accounts_dir.exists() { | |
| return Err(format!("账号目录不存在: {:?}", accounts_dir)); | |
| } | |
| let entries = std::fs::read_dir(&accounts_dir) | |
| .map_err(|e| format!("读取账号目录失败: {}", e))?; | |
| let mut count = 0; | |
| for entry in entries { | |
| let entry = entry.map_err(|e| format!("读取目录项失败: {}", e))?; | |
| let path = entry.path(); | |
| if path.extension().and_then(|s| s.to_str()) != Some("json") { | |
| continue; | |
| } | |
| // 尝试加载账号 | |
| match self.load_single_account(&path).await { | |
| Ok(Some(token)) => { | |
| let account_id = token.account_id.clone(); | |
| self.tokens.insert(account_id, token); | |
| count += 1; | |
| }, | |
| Ok(None) => { | |
| // 跳过无效账号 | |
| }, | |
| Err(e) => { | |
| tracing::warn!("加载账号失败 {:?}: {}", path, e); | |
| } | |
| } | |
| } | |
| Ok(count) | |
| } | |
| /// 加载单个账号 | |
| async fn load_single_account(&self, path: &PathBuf) -> Result<Option<ProxyToken>, String> { | |
| let content = std::fs::read_to_string(path) | |
| .map_err(|e| format!("读取文件失败: {}", e))?; | |
| let account: serde_json::Value = serde_json::from_str(&content) | |
| .map_err(|e| format!("解析 JSON 失败: {}", e))?; | |
| let account_id = account["id"].as_str() | |
| .ok_or("缺少 id 字段")? | |
| .to_string(); | |
| let email = account["email"].as_str() | |
| .ok_or("缺少 email 字段")? | |
| .to_string(); | |
| let token_obj = account["token"].as_object() | |
| .ok_or("缺少 token 字段")?; | |
| let access_token = token_obj["access_token"].as_str() | |
| .ok_or("缺少 access_token")? | |
| .to_string(); | |
| let refresh_token = token_obj["refresh_token"].as_str() | |
| .ok_or("缺少 refresh_token")? | |
| .to_string(); | |
| let expires_in = token_obj["expires_in"].as_i64() | |
| .ok_or("缺少 expires_in")?; | |
| let timestamp = token_obj["expiry_timestamp"].as_i64() | |
| .ok_or("缺少 expiry_timestamp")?; | |
| // project_id 是可选的 | |
| let project_id = token_obj.get("project_id") | |
| .and_then(|v| v.as_str()) | |
| .map(|s| s.to_string()); | |
| Ok(Some(ProxyToken { | |
| account_id, | |
| access_token, | |
| refresh_token, | |
| expires_in, | |
| timestamp, | |
| email, | |
| account_path: path.clone(), | |
| project_id, | |
| })) | |
| } | |
| /// 获取当前可用的 Token(带 60s 时间窗口锁定机制) | |
| /// 参数 `_quota_group` 用于区分 "claude" vs "gemini" 组 | |
| /// 参数 `force_rotate` 为 true 时将忽略锁定,强制切换账号 | |
| pub async fn get_token(&self, quota_group: &str, force_rotate: bool) -> Result<(String, String, String), String> { | |
| let total = self.tokens.len(); | |
| if total == 0 { | |
| return Err("Token pool is empty".to_string()); | |
| } | |
| // 1. 检查时间窗口锁定 (60秒内强制复用上一个账号) | |
| // 优化策略: 画图请求 (image_gen) 默认不锁定,以最大化并发能力 | |
| let mut target_token = None; | |
| if !force_rotate && quota_group != "image_gen" { | |
| let last_used = self.last_used_account.lock().await; | |
| if let Some((account_id, last_time)) = &*last_used { | |
| if last_time.elapsed().as_secs() < 60 { | |
| if let Some(entry) = self.tokens.get(account_id) { | |
| tracing::info!("60s 时间窗口内,强制复用上一个账号: {}", entry.email); | |
| target_token = Some(entry.value().clone()); | |
| } | |
| } | |
| } | |
| } | |
| // 2. 如果没有锁定、锁定失效或强制轮换,则进行轮询记录并更新锁定信息 | |
| let mut token = if let Some(t) = target_token { | |
| t | |
| } else { | |
| // 简单轮换策略 (Round Robin) | |
| let idx = self.current_index.fetch_add(1, Ordering::SeqCst) % total; | |
| let selected_token = self.tokens.iter() | |
| .nth(idx) | |
| .map(|entry| entry.value().clone()) | |
| .ok_or("Failed to retrieve token from pool")?; | |
| // 更新最后使用的账号及时间 (如果是普通对话请求) | |
| if quota_group != "image_gen" { | |
| let mut last_used = self.last_used_account.lock().await; | |
| *last_used = Some((selected_token.account_id.clone(), std::time::Instant::now())); | |
| } | |
| let action_msg = if force_rotate { "强制切换" } else { "切换" }; | |
| tracing::info!("{}到账号: {}", action_msg, selected_token.email); | |
| selected_token | |
| }; | |
| // 3. 检查 token 是否过期(提前5分钟刷新) | |
| let now = chrono::Utc::now().timestamp(); | |
| if now >= token.timestamp - 300 { | |
| tracing::info!("账号 {} 的 token 即将过期,正在刷新...", token.email); | |
| // 调用 OAuth 刷新 token | |
| match crate::modules::oauth::refresh_access_token(&token.refresh_token).await { | |
| Ok(token_response) => { | |
| tracing::info!("Token 刷新成功!"); | |
| // 更新本地内存对象供后续使用 | |
| token.access_token = token_response.access_token.clone(); | |
| token.expires_in = token_response.expires_in; | |
| token.timestamp = now + token_response.expires_in; | |
| // 同步更新跨线程共享的 DashMap | |
| if let Some(mut entry) = self.tokens.get_mut(&token.account_id) { | |
| entry.access_token = token.access_token.clone(); | |
| entry.expires_in = token.expires_in; | |
| entry.timestamp = token.timestamp; | |
| } | |
| } | |
| Err(e) => { | |
| tracing::error!("Token 刷新失败: {},尝试下一个账号", e); | |
| return Err(format!("Token refresh failed: {}", e)); | |
| } | |
| } | |
| } | |
| // 4. 确保有 project_id | |
| let project_id = if let Some(pid) = &token.project_id { | |
| pid.clone() | |
| } else { | |
| tracing::info!("账号 {} 缺少 project_id,尝试获取...", token.email); | |
| match crate::proxy::project_resolver::fetch_project_id(&token.access_token).await { | |
| Ok(pid) => { | |
| if let Some(mut entry) = self.tokens.get_mut(&token.account_id) { | |
| entry.project_id = Some(pid.clone()); | |
| } | |
| let _ = self.save_project_id(&token.account_id, &pid).await; | |
| pid | |
| } | |
| Err(e) => { | |
| tracing::error!("Failed to fetch project_id for {}: {}", token.email, e); | |
| return Err(format!("Failed to fetch project_id: {}", e)); | |
| } | |
| } | |
| }; | |
| Ok((token.access_token, project_id, token.email)) | |
| } | |
| /// 保存 project_id 到账号文件 | |
| async fn save_project_id(&self, account_id: &str, project_id: &str) -> Result<(), String> { | |
| let entry = self.tokens.get(account_id) | |
| .ok_or("账号不存在")?; | |
| let path = &entry.account_path; | |
| let mut content: serde_json::Value = serde_json::from_str( | |
| &std::fs::read_to_string(path).map_err(|e| format!("读取文件失败: {}", e))? | |
| ).map_err(|e| format!("解析 JSON 失败: {}", e))?; | |
| content["token"]["project_id"] = serde_json::Value::String(project_id.to_string()); | |
| std::fs::write(path, serde_json::to_string_pretty(&content).unwrap()) | |
| .map_err(|e| format!("写入文件失败: {}", e))?; | |
| tracing::info!("已保存 project_id 到账号 {}", account_id); | |
| Ok(()) | |
| } | |
| /// 保存刷新后的 token 到账号文件 | |
| async fn save_refreshed_token(&self, account_id: &str, token_response: &crate::modules::oauth::TokenResponse) -> Result<(), String> { | |
| let entry = self.tokens.get(account_id) | |
| .ok_or("账号不存在")?; | |
| let path = &entry.account_path; | |
| let mut content: serde_json::Value = serde_json::from_str( | |
| &std::fs::read_to_string(path).map_err(|e| format!("读取文件失败: {}", e))? | |
| ).map_err(|e| format!("解析 JSON 失败: {}", e))?; | |
| let now = chrono::Utc::now().timestamp(); | |
| content["token"]["access_token"] = serde_json::Value::String(token_response.access_token.clone()); | |
| content["token"]["expires_in"] = serde_json::Value::Number(token_response.expires_in.into()); | |
| content["token"]["expiry_timestamp"] = serde_json::Value::Number((now + token_response.expires_in).into()); | |
| std::fs::write(path, serde_json::to_string_pretty(&content).unwrap()) | |
| .map_err(|e| format!("写入文件失败: {}", e))?; | |
| tracing::info!("已保存刷新后的 token 到账号 {}", account_id); | |
| Ok(()) | |
| } | |
| pub fn len(&self) -> usize { | |
| self.tokens.len() | |
| } | |
| /// Get token count (alias for len) | |
| pub async fn token_count(&self) -> usize { | |
| self.tokens.len() | |
| } | |
| /// Add a token manually | |
| pub async fn add_token( | |
| &self, | |
| access_token: String, | |
| refresh_token: String, | |
| expiry_timestamp: i64, | |
| email: String, | |
| project_id: Option<String>, | |
| ) { | |
| let account_id = uuid::Uuid::new_v4().to_string(); | |
| let now = chrono::Utc::now().timestamp(); | |
| let token = ProxyToken { | |
| account_id: account_id.clone(), | |
| access_token, | |
| refresh_token, | |
| expires_in: expiry_timestamp - now, | |
| timestamp: expiry_timestamp, | |
| email, | |
| account_path: self.data_dir.join("accounts").join(format!("{}.json", account_id)), | |
| project_id, | |
| }; | |
| self.tokens.insert(account_id, token); | |
| } | |
| } | |