# src/core/auth_token_manager.py
import time
import json
import threading

from src.core.logger import logger  # import logger from its new location


class AuthTokenManager:
    """Track SSO tokens per model, enforce per-model request quotas, and
    periodically restore tokens that were rotated out.

    State:
        token_model_map: model name -> list of token entries. Each entry is a
            dict with keys "token", "RequestCount", "AddedTime" and
            "StartCallTime". The entry at index 0 is the one handed out for
            that model.
        expired_tokens: set of (token, model, removed_at_ms) tuples for tokens
            removed from a model's list by ``remove_token_from_model``.
        token_status_map: sso value -> model -> status dict with keys
            "isValid", "invalidatedTime" and "totalRequestCount".

    All timestamps are milliseconds since the epoch (``time.time() * 1000``).

    NOTE(review): several log messages include the full token value; consider
    masking it if logs are not treated as secret.
    """

    def __init__(self):
        self.token_model_map = {}
        self.expired_tokens = set()
        self.token_status_map = {}
        # RequestFrequency: how many requests one token may serve for the model
        # before get_next_token_for_model rotates it out.
        # ExpirationTime (ms): how long a rotated-out token waits before the
        # reset thread restores it. It is also the length of the usage window
        # after which a token's RequestCount is reset.
        self.model_config = {
            "grok-2": {
                "RequestFrequency": 30,
                "ExpirationTime": 1 * 60 * 60 * 1000,  # 1 hour
            },
            "grok-3": {
                "RequestFrequency": 20,
                "ExpirationTime": 2 * 60 * 60 * 1000,  # 2 hours
            },
            "grok-3-deepsearch": {
                "RequestFrequency": 10,
                "ExpirationTime": 24 * 60 * 60 * 1000,  # 24 hours
            },
            "grok-3-reasoning": {
                "RequestFrequency": 10,
                "ExpirationTime": 24 * 60 * 60 * 1000,  # 24 hours
            },
        }
        # Becomes True once the background reset thread has been started.
        self.token_reset_switch = False
        # Holds the background reset thread once it has been started.
        self.token_reset_timer = None

    @staticmethod
    def _extract_sso(token):
        """Return the value of the ``sso=`` cookie contained in *token*.

        Raises IndexError if *token* contains no ``sso=`` segment.
        """
        return token.split("sso=")[1].split(";")[0]

    def add_token(self, token):
        """Register *token* for every configured model.

        The token is not appended twice to a model's list. Status entries
        that already exist for its sso/model pair are left untouched.
        """
        sso = self._extract_sso(token)
        for model in self.model_config:
            model_tokens = self.token_model_map.setdefault(model, [])
            model_status = self.token_status_map.setdefault(sso, {})

            if not any(entry["token"] == token for entry in model_tokens):
                model_tokens.append({
                    "token": token,
                    "RequestCount": 0,
                    "AddedTime": int(time.time() * 1000),
                    "StartCallTime": None,
                })

            if model not in model_status:
                model_status[model] = {
                    "isValid": True,
                    "invalidatedTime": None,
                    "totalRequestCount": 0,
                }

    def set_token(self, token):
        """Replace every model's token list with a single fresh entry for
        *token*, and reset that token's status for all models.

        NOTE(review): expired_tokens and other tokens' status entries are not
        cleared here. The reset thread may therefore later restore previously
        rotated-out tokens into token_model_map; confirm whether that is
        intended.
        """
        models = list(self.model_config.keys())
        self.token_model_map = {
            model: [{
                "token": token,
                "RequestCount": 0,
                "AddedTime": int(time.time() * 1000),
                "StartCallTime": None,
            }]
            for model in models
        }
        sso = self._extract_sso(token)
        self.token_status_map[sso] = {
            model: {
                "isValid": True,
                "invalidatedTime": None,
                "totalRequestCount": 0,
            }
            for model in models
        }

    def delete_token(self, token):
        """Remove *token* from every model, from the pending-restore set, and
        drop its sso status.

        Returns True on success and False if anything raised (for example,
        a token with no ``sso=`` segment).
        """
        try:
            sso = self._extract_sso(token)
            for model in self.token_model_map:
                self.token_model_map[model] = [
                    entry for entry in self.token_model_map[model]
                    if entry["token"] != token
                ]
            # Forget any pending restore for this token. Otherwise the reset
            # thread would put the deleted token back into rotation once its
            # ExpirationTime elapses. Iterate a snapshot because the reset
            # thread may mutate the set concurrently.
            pending = {info for info in list(self.expired_tokens) if info[0] == token}
            self.expired_tokens.difference_update(pending)
            if sso in self.token_status_map:
                del self.token_status_map[sso]
            logger.info(f"令牌已成功移除: {token}", "TokenManager")
            return True
        except Exception as error:
            logger.error(f"令牌删除失败: {str(error)}", "TokenManager")
            return False

    def reduce_token_request_count(self, model_id, count):
        """Give back *count* requests to the current token of *model_id*.

        Both the entry's RequestCount and the sso's totalRequestCount for the
        model are lowered by the same amount, clamped at 0. Returns False if
        the model is unknown, has no tokens, or an error occurs.
        """
        try:
            normalized_model = self.normalize_model_name(model_id)

            if normalized_model not in self.token_model_map:
                logger.error(f"模型 {normalized_model} 不存在", "TokenManager")
                return False

            if not self.token_model_map[normalized_model]:
                logger.error(f"模型 {normalized_model} 没有可用的token", "TokenManager")
                return False

            token_entry = self.token_model_map[normalized_model][0]

            # Keep RequestCount from going below 0; `reduction` is the amount
            # actually subtracted.
            new_count = max(0, token_entry["RequestCount"] - count)
            reduction = token_entry["RequestCount"] - new_count
            token_entry["RequestCount"] = new_count

            # Mirror the reduction in the per-sso status.
            if token_entry["token"]:
                sso = self._extract_sso(token_entry["token"])
                model_status = self.token_status_map.get(sso, {}).get(normalized_model)
                if model_status is not None:
                    model_status["totalRequestCount"] = max(
                        0, model_status["totalRequestCount"] - reduction
                    )

            return True

        except Exception as error:
            logger.error(f"重置校对token请求次数时发生错误: {str(error)}", "TokenManager")
            return False

    def get_next_token_for_model(self, model_id, is_return=False):
        """Return the token to use for *model_id* and count one request on it.

        With ``is_return=True`` the current token is returned without
        counting anything. When the current token exceeds its
        RequestFrequency, it is rotated out via ``remove_token_from_model``.
        The request is then charged to the next token in line, or None is
        returned if none remain.
        """
        normalized_model = self.normalize_model_name(model_id)
        model_tokens = self.token_model_map.get(normalized_model)
        if not model_tokens:
            return None

        token_entry = model_tokens[0]
        if is_return:
            return token_entry["token"]

        if token_entry["StartCallTime"] is None:
            token_entry["StartCallTime"] = int(time.time() * 1000)

        if not self.token_reset_switch:
            self.start_token_reset_process()
            self.token_reset_switch = True

        token_entry["RequestCount"] += 1
        request_frequency = self.model_config[normalized_model]["RequestFrequency"]

        if token_entry["RequestCount"] > request_frequency:
            # Charge this request to the next token (setting its StartCallTime
            # and counters) instead of handing it out uncounted. Each recursive
            # step removes one entry, so recursion is bounded by the list size.
            if self.remove_token_from_model(normalized_model, token_entry["token"]):
                return self.get_next_token_for_model(normalized_model)
            return None

        sso = self._extract_sso(token_entry["token"])
        model_status = self.token_status_map.get(sso, {}).get(normalized_model)
        if model_status is not None:
            # The request that reaches the quota marks the status invalid.
            if token_entry["RequestCount"] == request_frequency:
                model_status["isValid"] = False
                model_status["invalidatedTime"] = int(time.time() * 1000)
            model_status["totalRequestCount"] += 1

        return token_entry["token"]

    def remove_token_from_model(self, model_id, token):
        """Take *token* out of *model_id*'s list and record it in
        expired_tokens with the removal time, so the reset thread can restore
        it later.

        Returns True if the token was found and removed, otherwise False.
        """
        normalized_model = self.normalize_model_name(model_id)

        if normalized_model not in self.token_model_map:
            logger.error(f"模型 {normalized_model} 不存在", "TokenManager")
            return False

        model_tokens = self.token_model_map[normalized_model]
        token_index = next(
            (i for i, entry in enumerate(model_tokens) if entry["token"] == token), -1
        )

        if token_index != -1:
            removed_token_entry = model_tokens.pop(token_index)
            self.expired_tokens.add((
                removed_token_entry["token"],
                normalized_model,
                int(time.time() * 1000),
            ))

            if not self.token_reset_switch:
                self.start_token_reset_process()
                self.token_reset_switch = True

            logger.info(f"模型{model_id}的令牌已失效,已成功移除令牌: {token}", "TokenManager")
            return True

        logger.error(f"在模型 {normalized_model} 中未找到 token: {token}", "TokenManager")
        return False

    def get_expired_tokens(self):
        """Return a list copy of the (token, model, removed_at_ms) tuples."""
        return list(self.expired_tokens)

    def normalize_model_name(self, model):
        """Map a ``grok-*`` model id to its config key.

        Ids without 'deepsearch' or 'reasoning' are truncated to their first
        two hyphen-separated parts (e.g. "grok-3-mini" -> "grok-3"). Any
        other id is returned unchanged.
        """
        if model.startswith('grok-') and 'deepsearch' not in model and 'reasoning' not in model:
            return '-'.join(model.split('-')[:2])
        return model

    def get_token_count_for_model(self, model_id):
        """Return how many tokens are currently in rotation for *model_id*."""
        normalized_model = self.normalize_model_name(model_id)
        return len(self.token_model_map.get(normalized_model, []))

    def get_remaining_token_request_capacity(self):
        """Return, per configured model, the remaining request capacity.

        Capacity is ``len(tokens) * RequestFrequency`` minus the requests
        already counted on those tokens, clamped at 0.
        """
        remaining_capacity_map = {}

        for model in self.model_config:
            model_tokens = self.token_model_map.get(model, [])
            model_request_frequency = self.model_config[model]["RequestFrequency"]

            total_used_requests = sum(
                token_entry.get("RequestCount", 0) for token_entry in model_tokens
            )

            remaining_capacity = (len(model_tokens) * model_request_frequency) - total_used_requests
            remaining_capacity_map[model] = max(0, remaining_capacity)

        return remaining_capacity_map

    def get_token_array_for_model(self, model_id):
        """Return the live (not copied) token-entry list for *model_id*."""
        normalized_model = self.normalize_model_name(model_id)
        return self.token_model_map.get(normalized_model, [])

    def start_token_reset_process(self):
        """Start a daemon thread that runs the reset pass once immediately
        and then once every hour.

        Each pass:
          * restores rotated-out tokens whose ExpirationTime has elapsed and
            resets their status;
          * resets RequestCount/StartCallTime and status for active tokens
            whose usage window (ExpirationTime since StartCallTime) has
            elapsed.
        """
        def reset_expired_tokens():
            now = int(time.time() * 1000)

            restored = set()
            # Iterate a snapshot: request threads add to expired_tokens while
            # this runs, and mutating a set during iteration raises
            # RuntimeError.
            for token_info in list(self.expired_tokens):
                token, model, expired_time = token_info
                expiration_time = self.model_config[model]["ExpirationTime"]

                if now - expired_time < expiration_time:
                    continue

                model_tokens = self.token_model_map.setdefault(model, [])
                if not any(entry["token"] == token for entry in model_tokens):
                    model_tokens.append({
                        "token": token,
                        "RequestCount": 0,
                        "AddedTime": now,
                        "StartCallTime": None,
                    })

                sso = self._extract_sso(token)
                model_status = self.token_status_map.get(sso, {}).get(model)
                if model_status is not None:
                    model_status["isValid"] = True
                    model_status["invalidatedTime"] = None
                    model_status["totalRequestCount"] = 0

                restored.add(token_info)

            self.expired_tokens.difference_update(restored)

            for model in self.model_config:
                if model not in self.token_model_map:
                    continue

                expiration_time = self.model_config[model]["ExpirationTime"]
                # Snapshot the list; entries are dicts, so updates still apply.
                for token_entry in list(self.token_model_map[model]):
                    if not token_entry.get("StartCallTime"):
                        continue

                    if now - token_entry["StartCallTime"] >= expiration_time:
                        sso = self._extract_sso(token_entry["token"])
                        model_status = self.token_status_map.get(sso, {}).get(model)
                        if model_status is not None:
                            model_status["isValid"] = True
                            model_status["invalidatedTime"] = None
                            model_status["totalRequestCount"] = 0

                        token_entry["RequestCount"] = 0
                        token_entry["StartCallTime"] = None

        # Background loop: one reset pass per hour. An exception in one pass
        # is logged instead of killing the thread. Otherwise token_reset_switch
        # would stay True and resets would silently stop forever.
        def run_timer():
            while True:
                try:
                    reset_expired_tokens()
                except Exception as error:
                    logger.error(f"重置令牌时发生错误: {str(error)}", "TokenManager")
                time.sleep(3600)

        timer_thread = threading.Thread(target=run_timer)
        timer_thread.daemon = True
        timer_thread.start()
        self.token_reset_timer = timer_thread

    def get_all_tokens(self):
        """Return the distinct tokens currently in any model's rotation."""
        all_tokens = set()
        for model_tokens in self.token_model_map.values():
            for entry in model_tokens:
                all_tokens.add(entry["token"])
        return list(all_tokens)

    def get_current_token(self, model_id):
        """Return *model_id*'s current (first) token without counting a
        request, or None if it has none."""
        normalized_model = self.normalize_model_name(model_id)

        if normalized_model not in self.token_model_map or not self.token_model_map[normalized_model]:
            return None

        token_entry = self.token_model_map[normalized_model][0]
        return token_entry["token"]

    def get_total_token_count(self):
        """Return the number of distinct tokens currently in rotation."""
        return len(self.get_all_tokens())

    def get_total_request_count(self):
        """Return the sum of totalRequestCount over every sso/model status."""
        total_count = 0
        for sso_status in self.token_status_map.values():
            for model_status in sso_status.values():
                total_count += model_status.get("totalRequestCount", 0)
        return total_count

    def get_token_status_map(self):
        """Return the live (not copied) sso -> model -> status mapping."""
        return self.token_status_map