g2api-test / src /core /auth_token_manager.py
misonL's picture
Initial project commit with gitignore
df4585d
Raw
History Blame Contribute Delete
12.5 kB
# src/core/auth_token_manager.py
import time
import json
import threading
from src.core.logger import logger # 从新的位置导入 logger
class AuthTokenManager:
    """Rotate SSO cookie tokens per model and enforce per-model request quotas.

    For every model in ``model_config`` the manager keeps a queue of token
    entries in ``token_model_map``; requests are always served by the entry at
    the head of the queue. Once a token has been used ``RequestFrequency``
    times it is marked invalid, and the following request moves it into
    ``expired_tokens``. A background daemon thread periodically reinstates
    removed tokens after the model's ``ExpirationTime`` (milliseconds) and
    resets quota windows that have elapsed. Per-SSO, per-model statistics are
    kept in ``token_status_map``.

    Mutable state is guarded by a re-entrant lock because the reset job runs
    on its own thread while callers use the manager concurrently.
    """

    # Seconds the background reset job sleeps between runs.
    RESET_INTERVAL_SECONDS = 3600

    def __init__(self):
        # model name -> list of token entries; index 0 is the token in use.
        self.token_model_map = {}
        # (token, model, removed_at_ms) tuples waiting to be reinstated.
        self.expired_tokens = set()
        # sso value -> model name -> {"isValid", "invalidatedTime", "totalRequestCount"}.
        self.token_status_map = {}
        self.model_config = {
            "grok-2": {
                "RequestFrequency": 30,
                "ExpirationTime": 1 * 60 * 60 * 1000,  # 1 hour, in ms
            },
            "grok-3": {
                "RequestFrequency": 20,
                "ExpirationTime": 2 * 60 * 60 * 1000,  # 2 hours, in ms
            },
            "grok-3-deepsearch": {
                "RequestFrequency": 10,
                "ExpirationTime": 24 * 60 * 60 * 1000,  # 24 hours, in ms
            },
            "grok-3-reasoning": {
                "RequestFrequency": 10,
                "ExpirationTime": 24 * 60 * 60 * 1000,  # 24 hours, in ms
            },
        }
        # True once the background reset thread has been started.
        self.token_reset_switch = False
        # The background reset thread, once started.
        self.token_reset_timer = None
        # Re-entrant so public methods can call each other while holding it.
        self._lock = threading.RLock()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _now_ms():
        """Return the current wall-clock time in integer milliseconds."""
        return int(time.time() * 1000)

    @staticmethod
    def _extract_sso(token):
        """Return the value of the ``sso=`` cookie inside *token*.

        Raises:
            IndexError: if *token* contains no ``sso=`` segment.
        """
        return token.split("sso=")[1].split(";")[0]

    @staticmethod
    def _new_token_entry(token, added_time):
        """Build a fresh, unused queue entry for *token*."""
        return {
            "token": token,
            "RequestCount": 0,
            "AddedTime": added_time,
            "StartCallTime": None,
        }

    @staticmethod
    def _new_status_entry():
        """Build a fresh status record (valid, no requests counted)."""
        return {
            "isValid": True,
            "invalidatedTime": None,
            "totalRequestCount": 0,
        }

    def _reset_status(self, sso, model):
        """Mark the sso/model status valid again and zero its counter, if tracked."""
        status = self.token_status_map.get(sso, {}).get(model)
        if status is not None:
            status["isValid"] = True
            status["invalidatedTime"] = None
            status["totalRequestCount"] = 0

    def _ensure_reset_timer(self):
        """Start the background reset thread the first time it is needed."""
        with self._lock:
            if not self.token_reset_switch:
                self.start_token_reset_process()
                self.token_reset_switch = True

    # --------------------------------------------------------------- mutation

    def add_token(self, token):
        """Register *token* for every configured model.

        A token already queued for a model is not queued twice; existing status
        records are left as they are.

        Raises:
            IndexError: if *token* contains no ``sso=`` segment.
        """
        sso = self._extract_sso(token)
        with self._lock:
            for model in self.model_config:
                entries = self.token_model_map.setdefault(model, [])
                if not any(entry["token"] == token for entry in entries):
                    entries.append(self._new_token_entry(token, self._now_ms()))
                self.token_status_map.setdefault(sso, {}).setdefault(
                    model, self._new_status_entry()
                )

    def set_token(self, token):
        """Make *token* the only queued token for every configured model.

        The token's status records are reset. The sso value is parsed before
        anything is modified, so a malformed token leaves the state unchanged.

        Raises:
            IndexError: if *token* contains no ``sso=`` segment.
        """
        sso = self._extract_sso(token)
        now = self._now_ms()
        with self._lock:
            self.token_model_map = {
                model: [self._new_token_entry(token, now)] for model in self.model_config
            }
            # NOTE(review): status records and expired_tokens entries of previously
            # registered tokens are not cleared, so the reset job may re-queue them
            # later; confirm whether that is intended.
            self.token_status_map[sso] = {
                model: self._new_status_entry() for model in self.model_config
            }

    def delete_token(self, token):
        """Remove *token* from all model queues, pending cooldowns and statistics.

        Returns:
            True on success; False (after logging the error) otherwise.
        """
        try:
            sso = self._extract_sso(token)
            with self._lock:
                for model in self.token_model_map:
                    self.token_model_map[model] = [
                        entry for entry in self.token_model_map[model] if entry["token"] != token
                    ]
                # Drop pending cooldown entries too; otherwise the reset job would
                # re-queue the deleted token once its cooldown elapsed.
                self.expired_tokens -= {info for info in self.expired_tokens if info[0] == token}
                self.token_status_map.pop(sso, None)
            # NOTE(review): this logs the full cookie value; consider masking it.
            logger.info(f"令牌已成功移除: {token}", "TokenManager")
            return True
        except Exception as error:
            logger.error(f"令牌删除失败: {str(error)}", "TokenManager")
            return False

    def reduce_token_request_count(self, model_id, count):
        """Give back *count* requests to the token currently in use for *model_id*.

        Neither the entry's ``RequestCount`` nor the status ``totalRequestCount``
        drops below zero.

        Returns:
            True on success; False (after logging) if the model is unknown, has
            no tokens, or an error occurs.
        """
        try:
            normalized_model = self.normalize_model_name(model_id)
            with self._lock:
                if normalized_model not in self.token_model_map:
                    logger.error(f"模型 {normalized_model} 不存在", "TokenManager")
                    return False
                entries = self.token_model_map[normalized_model]
                if not entries:
                    logger.error(f"模型 {normalized_model} 没有可用的token", "TokenManager")
                    return False
                token_entry = entries[0]
                # Never let RequestCount drop below zero; only subtract what was removed.
                new_count = max(0, token_entry["RequestCount"] - count)
                reduction = token_entry["RequestCount"] - new_count
                token_entry["RequestCount"] = new_count
                # Keep the status statistics in step with the entry.
                if token_entry["token"]:
                    sso = self._extract_sso(token_entry["token"])
                    status = self.token_status_map.get(sso, {}).get(normalized_model)
                    if status is not None:
                        status["totalRequestCount"] = max(0, status["totalRequestCount"] - reduction)
                return True
        except Exception as error:
            logger.error(f"重置校对token请求次数时发生错误: {str(error)}", "TokenManager")
            return False

    def get_next_token_for_model(self, model_id, is_return=False):
        """Return the token to use for the next request to *model_id*.

        Args:
            model_id: model name; normalized with ``normalize_model_name``.
            is_return: when True, only peek at the head token without counting
                a request.

        Returns:
            The token string, or None if the model has no usable token.

        The request is charged to the head token. A token whose count exceeds
        the model's ``RequestFrequency`` is removed (moved to ``expired_tokens``)
        and the request is charged to the next token instead.
        """
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            entries = self.token_model_map.get(normalized_model)
            if not entries:
                return None
            if is_return:
                return entries[0]["token"]

            limit = self.model_config[normalized_model]["RequestFrequency"]
            while entries:
                token_entry = entries[0]
                if token_entry["StartCallTime"] is None:
                    token_entry["StartCallTime"] = self._now_ms()
                self._ensure_reset_timer()
                token_entry["RequestCount"] += 1

                if token_entry["RequestCount"] > limit:
                    # Quota exhausted: retire this token (pops it from `entries`)
                    # and charge this request to the next one in the queue.
                    self.remove_token_from_model(normalized_model, token_entry["token"])
                    continue

                sso = self._extract_sso(token_entry["token"])
                status = self.token_status_map.get(sso, {}).get(normalized_model)
                if status is not None:
                    # The request that reaches the limit is still served, but the
                    # token is flagged invalid from now on.
                    if token_entry["RequestCount"] == limit:
                        status["isValid"] = False
                        status["invalidatedTime"] = self._now_ms()
                    status["totalRequestCount"] += 1
                return token_entry["token"]
            return None

    def remove_token_from_model(self, model_id, token):
        """Move *token* out of *model_id*'s queue into the cooldown set.

        Returns:
            True if the token was found and removed, False (after logging)
            otherwise.
        """
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            if normalized_model not in self.token_model_map:
                logger.error(f"模型 {normalized_model} 不存在", "TokenManager")
                return False
            model_tokens = self.token_model_map[normalized_model]
            token_index = next(
                (i for i, entry in enumerate(model_tokens) if entry["token"] == token), -1
            )
            if token_index == -1:
                logger.error(f"在模型 {normalized_model} 中未找到 token: {token}", "TokenManager")
                return False
            removed_token_entry = model_tokens.pop(token_index)
            self.expired_tokens.add((removed_token_entry["token"], normalized_model, self._now_ms()))
            self._ensure_reset_timer()
        logger.info(f"模型{model_id}的令牌已失效,已成功移除令牌: {token}", "TokenManager")
        return True

    # ------------------------------------------------------------------ queries

    def get_expired_tokens(self):
        """Return a snapshot list of ``(token, model, removed_at_ms)`` cooldown tuples."""
        with self._lock:
            return list(self.expired_tokens)

    def normalize_model_name(self, model):
        """Collapse ``grok-N-<suffix>`` names to ``grok-N``.

        Names containing ``deepsearch`` or ``reasoning``, and non-``grok-``
        names, are returned unchanged.
        """
        if model.startswith('grok-') and 'deepsearch' not in model and 'reasoning' not in model:
            return '-'.join(model.split('-')[:2])
        return model

    def get_token_count_for_model(self, model_id):
        """Return how many tokens are queued for *model_id*."""
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            return len(self.token_model_map.get(normalized_model, []))

    def get_remaining_token_request_capacity(self):
        """Return ``{model: requests still available across its queued tokens}``."""
        with self._lock:
            remaining_capacity_map = {}
            for model, config in self.model_config.items():
                model_tokens = self.token_model_map.get(model, [])
                used = sum(entry.get("RequestCount", 0) for entry in model_tokens)
                remaining = len(model_tokens) * config["RequestFrequency"] - used
                remaining_capacity_map[model] = max(0, remaining)
            return remaining_capacity_map

    def get_token_array_for_model(self, model_id):
        """Return the live queue list for *model_id* (not a copy), or ``[]``."""
        normalized_model = self.normalize_model_name(model_id)
        return self.token_model_map.get(normalized_model, [])

    # -------------------------------------------------------------- reset job

    def _reset_expired_tokens(self, now=None):
        """Reinstate cooled-down tokens and reset quota windows that have elapsed.

        Args:
            now: current time in milliseconds; defaults to the wall clock.
        """
        if now is None:
            now = self._now_ms()
        with self._lock:
            reinstated = set()
            # Iterate a snapshot so concurrent additions cannot break the loop.
            for token_info in list(self.expired_tokens):
                token, model, expired_time = token_info
                if now - expired_time < self.model_config[model]["ExpirationTime"]:
                    continue
                entries = self.token_model_map.setdefault(model, [])
                if not any(entry["token"] == token for entry in entries):
                    entries.append(self._new_token_entry(token, now))
                self._reset_status(self._extract_sso(token), model)
                reinstated.add(token_info)
            self.expired_tokens -= reinstated

            # Start a new quota window for queued tokens whose window has elapsed.
            for model, config in self.model_config.items():
                for token_entry in self.token_model_map.get(model, []):
                    start_call_time = token_entry.get("StartCallTime")
                    if not start_call_time:
                        continue
                    if now - start_call_time >= config["ExpirationTime"]:
                        self._reset_status(self._extract_sso(token_entry["token"]), model)
                        token_entry["RequestCount"] = 0
                        token_entry["StartCallTime"] = None

    def start_token_reset_process(self):
        """Start a daemon thread that runs the reset job now and then periodically.

        The interval is ``RESET_INTERVAL_SECONDS``. Callers inside this class go
        through ``token_reset_switch`` so only one timer is started.
        """
        def run_timer():
            while True:
                try:
                    self._reset_expired_tokens()
                except Exception as error:
                    # Keep the timer alive: an escaping error would stop all future resets.
                    logger.error(f"重置过期令牌时发生错误: {str(error)}", "TokenManager")
                time.sleep(self.RESET_INTERVAL_SECONDS)

        timer_thread = threading.Thread(target=run_timer, name="AuthTokenResetTimer", daemon=True)
        timer_thread.start()
        self.token_reset_timer = timer_thread

    # ---------------------------------------------------------------- reporting

    def get_all_tokens(self):
        """Return every distinct queued token, in first-seen order."""
        with self._lock:
            return list(dict.fromkeys(
                entry["token"]
                for model_tokens in self.token_model_map.values()
                for entry in model_tokens
            ))

    def get_current_token(self, model_id):
        """Return the head token for *model_id* without counting a request, or None."""
        normalized_model = self.normalize_model_name(model_id)
        with self._lock:
            entries = self.token_model_map.get(normalized_model)
            if not entries:
                return None
            return entries[0]["token"]

    def get_total_token_count(self):
        """Return the number of distinct queued tokens across all models."""
        return len(self.get_all_tokens())

    def get_total_request_count(self):
        """Return the sum of ``totalRequestCount`` over every sso/model status."""
        with self._lock:
            return sum(
                model_status.get("totalRequestCount", 0)
                for sso_status in self.token_status_map.values()
                for model_status in sso_status.values()
            )

    def get_token_status_map(self):
        """Return the live ``sso -> model -> status`` mapping (not a copy)."""
        return self.token_status_map