Spaces:
Paused
Paused
| # src/core/auth_token_manager.py | |
| import time | |
| import json | |
| import threading | |
| from src.core.logger import logger # 从新的位置导入 logger | |
class AuthTokenManager:
    """Track SSO cookie tokens per model, rotate them, and restore them after a cooldown.

    For every model in ``model_config`` there is a queue of token entries in
    ``token_model_map``; requests are always served from the head of the
    queue. A token whose request count exceeds the model's
    ``RequestFrequency`` is removed from the queue and parked in
    ``expired_tokens`` until ``ExpirationTime`` milliseconds have passed,
    after which a background job puts it back and resets its counters.

    Per-SSO bookkeeping (validity flag, invalidation time, total request
    count) lives in ``token_status_map``, keyed by the value of the ``sso=``
    cookie and then by model name.

    All shared state is guarded by a re-entrant lock because the background
    thread started by :meth:`start_token_reset_process` mutates the same
    containers that the request-path methods use.
    """

    # Seconds between runs of the background reset job.
    RESET_INTERVAL_SECONDS = 3600

    def __init__(self):
        # model name -> list of token entries; index 0 is the token in use.
        self.token_model_map = {}
        # Set of (token, model, removed_at_ms) tuples waiting for their cooldown.
        self.expired_tokens = set()
        # sso value -> model name -> {"isValid", "invalidatedTime", "totalRequestCount"}.
        self.token_status_map = {}
        # RequestFrequency: requests one token may serve before it is rotated out.
        # ExpirationTime: cooldown in milliseconds before counters are reset.
        self.model_config = {
            "grok-2": {
                "RequestFrequency": 30,
                "ExpirationTime": 1 * 60 * 60 * 1000  # 1 hour
            },
            "grok-3": {
                "RequestFrequency": 20,
                "ExpirationTime": 2 * 60 * 60 * 1000  # 2 hours
            },
            "grok-3-deepsearch": {
                "RequestFrequency": 10,
                "ExpirationTime": 24 * 60 * 60 * 1000  # 24 hours
            },
            "grok-3-reasoning": {
                "RequestFrequency": 10,
                "ExpirationTime": 24 * 60 * 60 * 1000  # 24 hours
            }
        }
        # Set once the background reset thread has been started.
        self.token_reset_switch = False
        # Holds the background reset thread once it is started.
        self.token_reset_timer = None
        # Re-entrant because public methods call each other while holding it
        # (e.g. get_next_token_for_model -> remove_token_from_model).
        self._lock = threading.RLock()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _extract_sso(token):
        """Return the value of the ``sso=`` cookie contained in *token*.

        Raises:
            IndexError: if *token* has no ``sso=`` segment.
        """
        return token.split("sso=")[1].split(";")[0]

    @staticmethod
    def _now_ms():
        """Return the current wall-clock time as integer milliseconds."""
        return int(time.time() * 1000)

    @staticmethod
    def _new_token_entry(token, added_time):
        """Build a fresh, unused queue entry for *token*."""
        return {
            "token": token,
            "RequestCount": 0,
            "AddedTime": added_time,
            "StartCallTime": None
        }

    @staticmethod
    def _new_status():
        """Build a fresh per-model status record: valid, never invalidated, zero requests."""
        return {
            "isValid": True,
            "invalidatedTime": None,
            "totalRequestCount": 0
        }

    def _get_status(self, sso, model):
        """Return the status record for (*sso*, *model*), or None if it is not tracked."""
        return self.token_status_map.get(sso, {}).get(model)

    def _reset_status(self, sso, model):
        """Mark (*sso*, *model*) valid again and zero its request total, if tracked."""
        status = self._get_status(sso, model)
        if status is not None:
            status["isValid"] = True
            status["invalidatedTime"] = None
            status["totalRequestCount"] = 0

    # --------------------------------------------------------------- mutation

    def add_token(self, token):
        """Register *token* for every configured model.

        A token already queued for a model is not queued twice. Existing
        status records for the token's SSO are left untouched.

        Raises:
            IndexError: if *token* has no ``sso=`` segment.
        """
        sso = self._extract_sso(token)
        added_time = self._now_ms()
        with self._lock:
            for model in self.model_config:
                queue = self.token_model_map.setdefault(model, [])
                if not any(entry["token"] == token for entry in queue):
                    queue.append(self._new_token_entry(token, added_time))
                self.token_status_map.setdefault(sso, {}).setdefault(model, self._new_status())

    def set_token(self, token):
        """Replace every model's queue with a single fresh entry for *token*.

        The token's status records are reset for all models. Status records
        of other SSOs are kept.

        NOTE(review): tuples already parked in ``expired_tokens`` are not
        cleared here, so the reset job may later re-queue previously used
        tokens. Confirm whether that is intended.

        Raises:
            IndexError: if *token* has no ``sso=`` segment. In that case
                nothing is modified.
        """
        # Parse first so a malformed token leaves the manager untouched.
        sso = self._extract_sso(token)
        added_time = self._now_ms()
        with self._lock:
            self.token_model_map = {
                model: [self._new_token_entry(token, added_time)]
                for model in self.model_config
            }
            self.token_status_map[sso] = {model: self._new_status() for model in self.model_config}

    def delete_token(self, token):
        """Remove *token* from every model queue, the expired pool, and the status map.

        Returns:
            True on success, False if an error occurred (e.g. no ``sso=`` segment).
        """
        try:
            sso = self._extract_sso(token)
            with self._lock:
                for model, queue in self.token_model_map.items():
                    self.token_model_map[model] = [entry for entry in queue if entry["token"] != token]
                # Also drop parked copies. Otherwise the reset job would re-queue
                # the deleted token once its cooldown elapsed.
                self.expired_tokens -= {info for info in self.expired_tokens if info[0] == token}
                self.token_status_map.pop(sso, None)
            logger.info(f"令牌已成功移除: {token}", "TokenManager")
            return True
        except Exception as error:
            logger.error(f"令牌删除失败: {str(error)}", "TokenManager")
            return False

    def reduce_token_request_count(self, model_id, count):
        """Give back up to *count* requests to the current head token of *model_id*.

        Both the entry's ``RequestCount`` and the SSO's ``totalRequestCount``
        are decreased by the same amount, and neither drops below zero.

        Returns:
            True on success, False if the model is unknown, has no tokens,
            or an error occurred.
        """
        try:
            normalized_model = self.normalize_model_name(model_id)
            with self._lock:
                if normalized_model not in self.token_model_map:
                    logger.error(f"模型 {normalized_model} 不存在", "TokenManager")
                    return False
                queue = self.token_model_map[normalized_model]
                if not queue:
                    logger.error(f"模型 {normalized_model} 没有可用的token", "TokenManager")
                    return False
                token_entry = queue[0]
                # Clamp at zero so RequestCount never goes negative.
                new_count = max(0, token_entry["RequestCount"] - count)
                reduction = token_entry["RequestCount"] - new_count
                token_entry["RequestCount"] = new_count
                if token_entry["token"]:
                    status = self._get_status(self._extract_sso(token_entry["token"]), normalized_model)
                    if status is not None:
                        status["totalRequestCount"] = max(0, status["totalRequestCount"] - reduction)
                return True
        except Exception as error:
            logger.error(f"重置校对token请求次数时发生错误: {str(error)}", "TokenManager")
            return False

    def get_next_token_for_model(self, model_id, is_return=False):
        """Return the token that should serve the next request for *model_id*.

        With ``is_return=True`` the current head token is returned without
        recording a request. Otherwise one request is charged to the head
        token. If that pushes it past the model's ``RequestFrequency``, the
        token is rotated out and the request is charged to the next token
        instead. That request is counted in full: it increments
        ``RequestCount``, sets ``StartCallTime``, and increments
        ``totalRequestCount``.

        Returns:
            The token string, or None if no usable token is queued.
        """
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            queue = self.token_model_map.get(normalized_model)
            if not queue:
                return None
            if is_return:
                return queue[0]["token"]
            frequency = self.model_config[normalized_model]["RequestFrequency"]
            while queue:
                token_entry = queue[0]
                if token_entry["StartCallTime"] is None:
                    token_entry["StartCallTime"] = self._now_ms()
                if not self.token_reset_switch:
                    self.start_token_reset_process()
                    self.token_reset_switch = True
                token_entry["RequestCount"] += 1
                if token_entry["RequestCount"] > frequency:
                    # Quota exceeded: park this token and try the next one. The
                    # removal pops from this same list, so the loop terminates.
                    self.remove_token_from_model(normalized_model, token_entry["token"])
                    continue
                status = self._get_status(self._extract_sso(token_entry["token"]), normalized_model)
                if status is not None:
                    if token_entry["RequestCount"] == frequency:
                        # This request uses the last allowed slot; flag it as exhausted.
                        status["isValid"] = False
                        status["invalidatedTime"] = self._now_ms()
                    status["totalRequestCount"] += 1
                return token_entry["token"]
            return None

    def remove_token_from_model(self, model_id, token):
        """Move *token* from *model_id*'s queue into the expired pool.

        Starts the background reset thread if it is not running yet.

        Returns:
            True if the token was found and removed, False otherwise.
        """
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            if normalized_model not in self.token_model_map:
                logger.error(f"模型 {normalized_model} 不存在", "TokenManager")
                return False
            model_tokens = self.token_model_map[normalized_model]
            token_index = next((i for i, entry in enumerate(model_tokens) if entry["token"] == token), -1)
            if token_index == -1:
                logger.error(f"在模型 {normalized_model} 中未找到 token: {token}", "TokenManager")
                return False
            removed_token_entry = model_tokens.pop(token_index)
            self.expired_tokens.add((
                removed_token_entry["token"],
                normalized_model,
                self._now_ms()
            ))
            if not self.token_reset_switch:
                self.start_token_reset_process()
                self.token_reset_switch = True
            logger.info(f"模型{model_id}的令牌已失效,已成功移除令牌: {token}", "TokenManager")
            return True

    # ---------------------------------------------------------------- queries

    def get_expired_tokens(self):
        """Return a list snapshot of the (token, model, removed_at_ms) tuples awaiting restore."""
        with self._lock:
            return list(self.expired_tokens)

    def normalize_model_name(self, model):
        """Collapse ``grok-N-<suffix>`` to ``grok-N``, except deepsearch/reasoning variants."""
        if model.startswith('grok-') and 'deepsearch' not in model and 'reasoning' not in model:
            return '-'.join(model.split('-')[:2])
        return model

    def get_token_count_for_model(self, model_id):
        """Return how many tokens are currently queued for *model_id*."""
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            return len(self.token_model_map.get(normalized_model, []))

    def get_remaining_token_request_capacity(self):
        """Return, per configured model, the requests left across all queued tokens (never negative)."""
        with self._lock:
            remaining_capacity_map = {}
            for model, config in self.model_config.items():
                model_tokens = self.token_model_map.get(model, [])
                total_used_requests = sum(entry.get("RequestCount", 0) for entry in model_tokens)
                remaining_capacity = len(model_tokens) * config["RequestFrequency"] - total_used_requests
                remaining_capacity_map[model] = max(0, remaining_capacity)
            return remaining_capacity_map

    def get_token_array_for_model(self, model_id):
        """Return the live queue list for *model_id* (empty list if unknown)."""
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            return self.token_model_map.get(normalized_model, [])

    # ------------------------------------------------------- background reset

    def _reset_expired_tokens(self, now=None):
        """Restore parked tokens and reset counters whose cooldown has elapsed.

        Args:
            now: current time in milliseconds. Defaults to the wall clock;
                pass a value to run the job deterministically.
        """
        if now is None:
            now = self._now_ms()
        with self._lock:
            restored = set()
            # Iterate a snapshot; the set itself is shrunk after the loop.
            for token_info in list(self.expired_tokens):
                token, model, expired_time = token_info
                if now - expired_time < self.model_config[model]["ExpirationTime"]:
                    continue
                queue = self.token_model_map.setdefault(model, [])
                if not any(entry["token"] == token for entry in queue):
                    queue.append(self._new_token_entry(token, now))
                self._reset_status(self._extract_sso(token), model)
                restored.add(token_info)
            self.expired_tokens -= restored

            # Tokens still queued get their counters reset once their usage
            # window (measured from StartCallTime) exceeds the cooldown.
            for model, config in self.model_config.items():
                if model not in self.token_model_map:
                    continue
                expiration_time = config["ExpirationTime"]
                for token_entry in self.token_model_map[model]:
                    start_call_time = token_entry.get("StartCallTime")
                    if not start_call_time or now - start_call_time < expiration_time:
                        continue
                    self._reset_status(self._extract_sso(token_entry["token"]), model)
                    token_entry["RequestCount"] = 0
                    token_entry["StartCallTime"] = None

    def start_token_reset_process(self):
        """Start a daemon thread that runs the reset job now and then every RESET_INTERVAL_SECONDS.

        Callers gate this on ``token_reset_switch`` so only one thread is started.
        """
        def run_timer():
            while True:
                try:
                    self._reset_expired_tokens()
                except Exception as error:
                    # Keep the loop alive. An uncaught error would end this
                    # thread for good, and token_reset_switch would stop it
                    # from ever being restarted.
                    logger.error(f"令牌重置任务执行失败: {str(error)}", "TokenManager")
                time.sleep(self.RESET_INTERVAL_SECONDS)

        timer_thread = threading.Thread(target=run_timer, daemon=True)
        timer_thread.start()
        self.token_reset_timer = timer_thread

    # ------------------------------------------------------------- summaries

    def get_all_tokens(self):
        """Return the distinct tokens queued for any model (order unspecified)."""
        with self._lock:
            return list({entry["token"] for queue in self.token_model_map.values() for entry in queue})

    def get_current_token(self, model_id):
        """Return the head token for *model_id* without recording a request, or None."""
        return self.get_next_token_for_model(model_id, is_return=True)

    def get_total_token_count(self):
        """Return the number of distinct queued tokens."""
        return len(self.get_all_tokens())

    def get_total_request_count(self):
        """Return the sum of ``totalRequestCount`` over every SSO and model."""
        with self._lock:
            return sum(
                model_status.get("totalRequestCount", 0)
                for sso_status in self.token_status_map.values()
                for model_status in sso_status.values()
            )

    def get_token_status_map(self):
        """Return the live status map (sso -> model -> status).

        The background thread may mutate it; callers should treat it as
        read-only.
        """
        return self.token_status_map