# 2api/src/storage/mongodb_manager.py
# Uploaded by lin7zhi using huggingface_hub (commit 69fec20, verified)
"""
MongoDB 存储管理器
"""
import os
import time
import re
from typing import Any, Dict, List, Optional
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from log import log
class MongoDBManager:
    """MongoDB-backed storage manager for credentials and configuration.

    Credentials are kept in one collection per mode (``credentials`` for
    ``geminicli`` and ``antigravity_credentials`` for ``antigravity``).
    Configuration is kept in the ``config`` collection and mirrored in an
    in-memory cache that is filled when :meth:`initialize` runs.
    """

    # State fields that update_credential_state() is allowed to write.
    STATE_FIELDS = {
        "error_codes",
        "disabled",
        "last_success",
        "user_email",
        "model_cooldowns",
    }

    # Projection that fetches only the state fields (never credential_data).
    _STATE_PROJECTION = {
        "filename": 1,
        "disabled": 1,
        "error_codes": 1,
        "last_success": 1,
        "user_email": 1,
        "model_cooldowns": 1,
        "_id": 0,
    }

    def __init__(self):
        self._client: Optional[AsyncIOMotorClient] = None
        self._db: Optional[AsyncIOMotorDatabase] = None
        self._initialized = False
        # In-memory config cache, loaded once during initialize().
        self._config_cache: Dict[str, Any] = {}
        self._config_loaded = False

    # ============ Lifecycle ============

    async def initialize(self) -> None:
        """Connect to MongoDB, create indexes and load the config cache.

        Reads ``MONGODB_URI`` (required) and ``MONGODB_DATABASE`` (defaults to
        ``gcli2api``) from the environment. Calling this again after a
        successful initialization is a no-op.

        Raises:
            ValueError: ``MONGODB_URI`` is not set.
            Exception: any error raised while pinging the server or creating
                indexes is logged and re-raised.
        """
        if self._initialized:
            return
        try:
            mongodb_uri = os.getenv("MONGODB_URI")
            if not mongodb_uri:
                raise ValueError("MONGODB_URI environment variable not set")
            database_name = os.getenv("MONGODB_DATABASE", "gcli2api")
            self._client = AsyncIOMotorClient(mongodb_uri)
            self._db = self._client[database_name]
            # Verify the connection before doing anything else.
            await self._db.command("ping")
            await self._create_indexes()
            await self._load_config_cache()
            self._initialized = True
            log.info(f"MongoDB storage initialized (database: {database_name})")
        except Exception as e:
            log.error(f"Error initializing MongoDB: {e}")
            # Do not leave a half-open client behind when setup fails part-way.
            if self._client is not None:
                self._client.close()
            self._client = None
            self._db = None
            raise

    async def _create_indexes(self):
        """Create the indexes used by both credential collections."""
        for collection_name in ("credentials", "antigravity_credentials"):
            collection = self._db[collection_name]
            await collection.create_index("filename", unique=True)
            await collection.create_index("disabled")
            await collection.create_index("rotation_order")
            # Compound index for "enabled credentials in rotation order".
            await collection.create_index([("disabled", 1), ("rotation_order", 1)])
            # Supports filtering by error code.
            await collection.create_index("error_codes")
        log.debug("MongoDB indexes created")

    async def _load_config_cache(self):
        """Load every document of the ``config`` collection into the cache.

        Does nothing when the cache is already marked as loaded. The cache is
        rebuilt from scratch so that keys removed from the database do not
        survive a reload. On failure the error is logged and the cache is
        reset to an empty dict (the loaded flag stays unset).
        """
        if self._config_loaded:
            return
        try:
            fresh_cache: Dict[str, Any] = {}
            async for doc in self._db["config"].find({}):
                fresh_cache[doc["key"]] = doc.get("value")
            self._config_cache = fresh_cache
            self._config_loaded = True
            log.debug(f"Loaded {len(self._config_cache)} config items into cache")
        except Exception as e:
            log.error(f"Error loading config cache: {e}")
            self._config_cache = {}

    async def close(self) -> None:
        """Close the MongoDB connection if one is open."""
        if self._client is not None:
            self._client.close()
            self._client = None
            self._db = None
            self._initialized = False
            # Force the next initialize() to reload config from the database.
            self._config_loaded = False
            log.debug("MongoDB storage closed")

    def _ensure_initialized(self):
        """Raise ``RuntimeError`` unless initialize() has completed."""
        if not self._initialized:
            raise RuntimeError("MongoDB manager not initialized")

    def _get_collection_name(self, mode: str) -> str:
        """Return the collection name for *mode*.

        Raises:
            ValueError: *mode* is neither ``"geminicli"`` nor ``"antigravity"``.
        """
        if mode == "antigravity":
            return "antigravity_credentials"
        if mode == "geminicli":
            return "credentials"
        raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'")

    # ============ Internal helpers ============

    @staticmethod
    def _basename_pattern(filename: str) -> str:
        """Return a regex matching *filename* itself or any path ending in it.

        The name must start at the beginning of the stored value or directly
        after a ``/`` or ``\\`` separator, so ``a.json`` matches ``dir/a.json``
        but not ``ba.json``.
        """
        return rf"(?:^|[/\\]){re.escape(filename)}$"

    @staticmethod
    def _basename_filter(filename: str) -> Dict[str, Any]:
        """Return the query used as a fallback for legacy path-style filenames."""
        return {"filename": {"$regex": MongoDBManager._basename_pattern(filename)}}

    @staticmethod
    def _flatten_cooldowns(cooldowns: Any, prefix: str = "") -> Dict[str, Any]:
        """Flatten nested cooldown sub-documents into dotted keys.

        set_model_cooldown() writes ``model_cooldowns.<model_key>`` with dot
        notation, so a key containing dots is stored as nested documents.
        Joining the nested keys with ``.`` restores the original model key.
        Anything that is not a dict yields an empty result.
        """
        flat: Dict[str, Any] = {}
        if not isinstance(cooldowns, dict):
            return flat
        for key, value in cooldowns.items():
            full_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(MongoDBManager._flatten_cooldowns(value, full_key))
            else:
                flat[full_key] = value
        return flat

    @staticmethod
    def _active_cooldowns(model_cooldowns: Any, current_time: float) -> Dict[str, Any]:
        """Return only the cooldowns that expire after *current_time*.

        Non-numeric values are ignored instead of raising on comparison.
        """
        flattened = MongoDBManager._flatten_cooldowns(model_cooldowns)
        return {
            key: until
            for key, until in flattened.items()
            if isinstance(until, (int, float)) and until > current_time
        }

    @staticmethod
    def _state_from_doc(doc: Dict[str, Any]) -> Dict[str, Any]:
        """Build a state dict from *doc*, filling defaults for missing fields."""
        return {
            "disabled": doc.get("disabled", False),
            "error_codes": doc.get("error_codes", []),
            "last_success": doc.get("last_success", time.time()),
            "user_email": doc.get("user_email"),
            "model_cooldowns": doc.get("model_cooldowns", {}),
        }

    @staticmethod
    def _tally_disabled_stats(grouped: List[Dict[str, Any]]) -> Dict[str, int]:
        """Fold ``{"_id": <disabled>, "count": n}`` groups into global stats.

        Every falsy ``_id`` (``False``, ``None``/missing) counts as "normal";
        counts are accumulated so several falsy groups do not overwrite each
        other.
        """
        stats = {"total": 0, "normal": 0, "disabled": 0}
        for item in grouped:
            count = item["count"]
            stats["total"] += count
            stats["disabled" if item["_id"] else "normal"] += count
        return stats

    # ============ Credential selection ============

    async def get_next_available_credential(
        self, mode: str = "geminicli", model_key: Optional[str] = None
    ) -> Optional[tuple[str, Dict[str, Any]]]:
        """Pick one random usable credential.

        A credential is usable when it is not disabled and, if *model_key* is
        given, has no cooldown for that key or the cooldown timestamp is not
        in the future. Filtering and sampling are done by an aggregation
        pipeline on the server.

        Args:
            mode: credential mode (``"geminicli"`` or ``"antigravity"``).
            model_key: key used for the per-model cooldown check.

        Returns:
            ``(filename, credential_data)`` or ``None`` when nothing is
            available or an error occurred (the error is logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            current_time = time.time()
            # Stage 1: only credentials that are not disabled.
            pipeline: List[Dict[str, Any]] = [{"$match": {"disabled": False}}]
            if model_key:
                # The dotted path is resolved by MongoDB as nested fields,
                # which matches how set_model_cooldown() writes the value.
                cooldown_path = f"$model_cooldowns.{model_key}"
                pipeline.extend([
                    # Stage 2: compute availability for this model key.
                    {
                        "$addFields": {
                            "is_available": {
                                "$or": [
                                    # No cooldown recorded for this key ...
                                    {"$not": {"$ifNull": [cooldown_path, False]}},
                                    # ... or the cooldown has already expired.
                                    {"$lte": [cooldown_path, current_time]},
                                ]
                            }
                        }
                    },
                    # Stage 3: keep available credentials only.
                    {"$match": {"is_available": True}},
                ])
            # Stage 4: random pick.
            pipeline.append({"$sample": {"size": 1}})
            # Stage 5: return only the fields the caller needs.
            pipeline.append({
                "$project": {
                    "filename": 1,
                    "credential_data": 1,
                    "_id": 0,
                }
            })
            docs = await collection.aggregate(pipeline).to_list(length=1)
            if not docs:
                return None
            picked = docs[0]
            return picked["filename"], picked.get("credential_data")
        except Exception as e:
            log.error(f"Error getting next available credential (mode={mode}, model_key={model_key}): {e}")
            return None

    async def get_available_credentials_list(self, mode: str = "geminicli") -> List[str]:
        """Return filenames of all enabled credentials in rotation order.

        Returns an empty list when an error occurs (the error is logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            pipeline = [
                {"$match": {"disabled": False}},
                {"$sort": {"rotation_order": 1}},
                {"$project": {"filename": 1, "_id": 0}},
            ]
            docs = await collection.aggregate(pipeline).to_list(length=None)
            return [doc["filename"] for doc in docs]
        except Exception as e:
            log.error(f"Error getting available credentials list (mode={mode}): {e}")
            return []

    # ============ StorageBackend protocol methods ============

    async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool:
        """Store a new credential or update the data of an existing one.

        An existing document only has ``credential_data`` and ``updated_at``
        refreshed. A new document is inserted with default state fields and
        the next ``rotation_order``.

        Returns:
            ``True`` on success, ``False`` when an error occurred (logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            current_ts = time.time()
            refresh = {"credential_data": credential_data, "updated_at": current_ts}

            # Try to update an existing document first.
            result = await collection.update_one({"filename": filename}, {"$set": refresh})

            if result.matched_count == 0:
                # New credential: compute the next rotation_order.
                pipeline = [
                    {"$group": {"_id": None, "max_order": {"$max": "$rotation_order"}}},
                    # $max is null when no document carries rotation_order;
                    # $ifNull keeps $add from producing null in that case.
                    {"$project": {"_id": 0, "next_order": {"$add": [{"$ifNull": ["$max_order", -1]}, 1]}}},
                ]
                result_list = await collection.aggregate(pipeline).to_list(length=1)
                next_order = result_list[0]["next_order"] if result_list else 0
                try:
                    await collection.insert_one({
                        "filename": filename,
                        "credential_data": credential_data,
                        "disabled": False,
                        "error_codes": [],
                        "last_success": current_ts,
                        "user_email": None,
                        "model_cooldowns": {},
                        "rotation_order": next_order,
                        "call_count": 0,
                        "created_at": current_ts,
                        "updated_at": current_ts,
                    })
                except Exception as insert_error:
                    # A concurrent insert of the same filename trips the
                    # unique index; fall back to a plain update.
                    if "duplicate key" in str(insert_error).lower():
                        await collection.update_one({"filename": filename}, {"$set": refresh})
                    else:
                        raise
            log.debug(f"Stored credential: {filename} (mode={mode})")
            return True
        except Exception as e:
            log.error(f"Error storing credential {filename}: {e}")
            return False

    async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]:
        """Return the credential data stored for *filename*.

        An exact filename match is tried first; if that fails, legacy
        documents whose stored filename is a path ending in *filename* are
        tried. Returns ``None`` when nothing matches or on error (logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            projection = {"credential_data": 1, "_id": 0}
            doc = await collection.find_one({"filename": filename}, projection)
            if doc is None:
                # Fallback for legacy data that stored a full path.
                doc = await collection.find_one(self._basename_filter(filename), projection)
            if doc:
                return doc.get("credential_data")
            return None
        except Exception as e:
            log.error(f"Error getting credential {filename}: {e}")
            return None

    async def list_credentials(self, mode: str = "geminicli") -> List[str]:
        """Return every credential filename in rotation order.

        Returns an empty list when an error occurs (the error is logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            pipeline = [
                {"$sort": {"rotation_order": 1}},
                {"$project": {"filename": 1, "_id": 0}},
            ]
            docs = await collection.aggregate(pipeline).to_list(length=None)
            return [doc["filename"] for doc in docs]
        except Exception as e:
            log.error(f"Error listing credentials: {e}")
            return []

    async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool:
        """Delete the credential stored for *filename*.

        An exact match is tried first, then the legacy path-style fallback.
        The fallback only matches a true basename, so it cannot delete a
        different credential whose name merely ends with the same text.

        Returns:
            ``True`` when a document was deleted, otherwise ``False``.
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            result = await collection.delete_one({"filename": filename})
            deleted_count = result.deleted_count
            if deleted_count == 0:
                result = await collection.delete_one(self._basename_filter(filename))
                deleted_count = result.deleted_count
            if deleted_count > 0:
                log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})")
                return True
            log.warning(f"No credential found to delete: {filename} (mode={mode})")
            return False
        except Exception as e:
            log.error(f"Error deleting credential {filename}: {e}")
            return False

    async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]:
        """Group credentials by ``user_email`` to support de-duplication.

        Only ``filename`` and ``user_email`` are fetched. Within each email
        group the files are ordered by filename; the first one is reported as
        kept and the rest as duplicates.

        Args:
            mode: credential mode (``"geminicli"`` or ``"antigravity"``).

        Returns:
            A dict with ``email_groups``, ``duplicate_groups``,
            ``duplicate_count``, ``no_email_files``, ``no_email_count``,
            ``unique_email_count`` and ``total_count``. On error the same keys
            are returned with empty/zero values.
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            pipeline = [
                {"$project": {"filename": 1, "user_email": 1, "_id": 0}},
                {"$sort": {"filename": 1}},
            ]
            docs = await collection.aggregate(pipeline).to_list(length=None)

            email_to_files: Dict[str, List[Any]] = {}
            no_email_files = []
            for doc in docs:
                filename = doc.get("filename")
                user_email = doc.get("user_email")
                if user_email:
                    email_to_files.setdefault(user_email, []).append(filename)
                else:
                    no_email_files.append(filename)

            duplicate_groups = [
                {
                    "email": email,
                    "kept_file": files[0],
                    "duplicate_files": files[1:],
                    "duplicate_count": len(files) - 1,
                }
                for email, files in email_to_files.items()
                if len(files) > 1
            ]
            total_duplicate_count = sum(group["duplicate_count"] for group in duplicate_groups)

            return {
                "email_groups": email_to_files,
                "duplicate_groups": duplicate_groups,
                "duplicate_count": total_duplicate_count,
                "no_email_files": no_email_files,
                "no_email_count": len(no_email_files),
                "unique_email_count": len(email_to_files),
                "total_count": len(docs),
            }
        except Exception as e:
            log.error(f"Error getting duplicate credentials by email: {e}")
            return {
                "email_groups": {},
                "duplicate_groups": [],
                "duplicate_count": 0,
                "no_email_files": [],
                "no_email_count": 0,
                "unique_email_count": 0,
                "total_count": 0,
            }

    async def update_credential_state(
        self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli"
    ) -> bool:
        """Update state fields of a credential.

        Keys of *state_updates* that are not in ``STATE_FIELDS`` are ignored.
        An exact filename match is tried first, then the legacy path-style
        fallback.

        Returns:
            ``True`` when a document matched or when there was nothing valid
            to update; ``False`` when no document matched or on error.
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            valid_updates = {
                field: value
                for field, value in state_updates.items()
                if field in self.STATE_FIELDS
            }
            if not valid_updates:
                return True
            valid_updates["updated_at"] = time.time()

            result = await collection.update_one({"filename": filename}, {"$set": valid_updates})
            # matched_count is used because an update that sets identical
            # values matches the document without modifying it.
            if result.matched_count == 0:
                result = await collection.update_one(
                    self._basename_filter(filename), {"$set": valid_updates}
                )
            return result.matched_count > 0
        except Exception as e:
            log.error(f"Error updating credential state {filename}: {e}")
            return False

    async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]:
        """Return the state of a credential.

        An exact filename match is tried first, then the legacy path-style
        fallback. When nothing matches, a default state is returned
        (enabled, no errors, ``last_success`` set to now). ``model_cooldowns``
        is returned as stored, without removing expired entries.

        Returns:
            The state dict, or ``{}`` when an error occurred (logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            doc = await collection.find_one({"filename": filename}, self._STATE_PROJECTION)
            if not doc:
                doc = await collection.find_one(
                    self._basename_filter(filename), self._STATE_PROJECTION
                )
            # An empty dict yields the default state.
            return self._state_from_doc(doc or {})
        except Exception as e:
            log.error(f"Error getting credential state {filename}: {e}")
            return {}

    async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]:
        """Return the state of every credential keyed by filename.

        Expired model cooldowns are dropped from the result. Returns ``{}``
        when an error occurs (the error is logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            cursor = collection.find({}, projection=self._STATE_PROJECTION)
            states: Dict[str, Dict[str, Any]] = {}
            current_time = time.time()
            async for doc in cursor:
                state = self._state_from_doc(doc)
                state["model_cooldowns"] = self._active_cooldowns(
                    doc.get("model_cooldowns"), current_time
                )
                states[doc["filename"]] = state
            return states
        except Exception as e:
            log.error(f"Error getting all credential states: {e}")
            return {}

    async def get_credentials_summary(
        self,
        offset: int = 0,
        limit: Optional[int] = None,
        status_filter: str = "all",
        mode: str = "geminicli",
        error_code_filter: Optional[str] = None,
        cooldown_filter: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return a paginated summary of credentials without credential data.

        Args:
            offset: number of records to skip (negative values are treated
                as 0 when slicing).
            limit: maximum number of records to return (``None`` = all).
            status_filter: ``"all"``, ``"enabled"`` or ``"disabled"``.
            mode: credential mode (``"geminicli"`` or ``"antigravity"``).
            error_code_filter: keep only credentials whose ``error_codes``
                contain this code (matched as string and, when numeric, as
                int); ``"all"`` or empty disables the filter.
            cooldown_filter: ``"in_cooldown"`` keeps credentials with at least
                one active cooldown, ``"no_cooldown"`` keeps those with none;
                any other value disables the filter.

        Returns:
            A dict with ``items``, ``total`` (count after filtering, before
            pagination), ``offset``, ``limit`` and ``stats`` (global counts
            that ignore all filters). On error, the same keys with empty/zero
            values.
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]

            query: Dict[str, Any] = {}
            if status_filter == "enabled":
                query["disabled"] = False
            elif status_filter == "disabled":
                query["disabled"] = True

            # Error codes may be stored as numbers or as strings; match both.
            if error_code_filter and str(error_code_filter).strip().lower() != "all":
                filter_value = str(error_code_filter).strip()
                query_values: List[Any] = [filter_value]
                try:
                    query_values.append(int(filter_value))
                except ValueError:
                    pass
                query["error_codes"] = {"$in": query_values}

            # Global statistics are computed over the whole collection.
            stats_pipeline = [{"$group": {"_id": "$disabled", "count": {"$sum": 1}}}]
            stats_result = await collection.aggregate(stats_pipeline).to_list(length=10)
            global_stats = self._tally_disabled_stats(stats_result)

            # The cooldown filter is evaluated in Python, so all matching
            # documents are fetched before pagination is applied.
            cursor = collection.find(
                query,
                projection={
                    "filename": 1,
                    "disabled": 1,
                    "error_codes": 1,
                    "last_success": 1,
                    "user_email": 1,
                    "rotation_order": 1,
                    "model_cooldowns": 1,
                    "_id": 0,
                },
            ).sort("rotation_order", 1)

            all_summaries = []
            current_time = time.time()
            async for doc in cursor:
                active_cooldowns = self._active_cooldowns(
                    doc.get("model_cooldowns"), current_time
                )
                if cooldown_filter == "in_cooldown" and not active_cooldowns:
                    continue
                if cooldown_filter == "no_cooldown" and active_cooldowns:
                    continue
                all_summaries.append({
                    "filename": doc["filename"],
                    "disabled": doc.get("disabled", False),
                    "error_codes": doc.get("error_codes", []),
                    "last_success": doc.get("last_success", current_time),
                    "user_email": doc.get("user_email"),
                    "rotation_order": doc.get("rotation_order", 0),
                    "model_cooldowns": active_cooldowns,
                })

            # Pagination; negative values would otherwise produce
            # wrap-around slices.
            total_count = len(all_summaries)
            start = max(offset, 0)
            if limit is not None:
                summaries = all_summaries[start:start + max(limit, 0)]
            else:
                summaries = all_summaries[start:]

            return {
                "items": summaries,
                "total": total_count,
                "offset": offset,
                "limit": limit,
                "stats": global_stats,
            }
        except Exception as e:
            log.error(f"Error getting credentials summary: {e}")
            return {
                "items": [],
                "total": 0,
                "offset": offset,
                "limit": limit,
                "stats": {"total": 0, "normal": 0, "disabled": 0},
            }

    # ============ Configuration (in-memory cache) ============

    async def set_config(self, key: str, value: Any) -> bool:
        """Write a config value to the database and to the cache.

        Returns ``True`` on success, ``False`` on error (logged).
        """
        self._ensure_initialized()
        try:
            await self._db["config"].update_one(
                {"key": key},
                {"$set": {"value": value, "updated_at": time.time()}},
                upsert=True,
            )
            self._config_cache[key] = value
            return True
        except Exception as e:
            log.error(f"Error setting config {key}: {e}")
            return False

    async def reload_config_cache(self):
        """Reload the config cache from the database."""
        self._ensure_initialized()
        self._config_loaded = False
        await self._load_config_cache()
        log.info("Config cache reloaded from database")

    async def get_config(self, key: str, default: Any = None) -> Any:
        """Return a config value from the cache, or *default* if absent."""
        self._ensure_initialized()
        return self._config_cache.get(key, default)

    async def get_all_config(self) -> Dict[str, Any]:
        """Return a shallow copy of the whole config cache."""
        self._ensure_initialized()
        return self._config_cache.copy()

    async def delete_config(self, key: str) -> bool:
        """Delete a config value from the database and from the cache.

        Returns ``True`` when a document was deleted, ``False`` when none
        matched or on error (logged).
        """
        self._ensure_initialized()
        try:
            result = await self._db["config"].delete_one({"key": key})
            self._config_cache.pop(key, None)
            return result.deleted_count > 0
        except Exception as e:
            log.error(f"Error deleting config {key}: {e}")
            return False

    # ============ Per-model cooldown management ============

    async def set_model_cooldown(
        self,
        filename: str,
        model_key: str,
        cooldown_until: Optional[float],
        mode: str = "geminicli"
    ) -> bool:
        """Set or clear the cooldown of one model on one credential.

        The change is a single atomic update on the document whose filename
        matches exactly.

        Args:
            filename: credential filename.
            model_key: model key the cooldown applies to.
            cooldown_until: timestamp until which the model is cooling down;
                ``None`` removes the cooldown.
            mode: credential mode (``"geminicli"`` or ``"antigravity"``).

        Returns:
            ``True`` on success; ``False`` when the credential was not found
            or an error occurred (logged).
        """
        self._ensure_initialized()
        try:
            collection = self._db[self._get_collection_name(mode)]
            # NOTE: dots inside model_key are interpreted by MongoDB as path
            # separators, so such keys are stored as nested sub-documents.
            cooldown_field = f"model_cooldowns.{model_key}"
            if cooldown_until is None:
                update = {
                    "$unset": {cooldown_field: ""},
                    "$set": {"updated_at": time.time()},
                }
            else:
                update = {
                    "$set": {
                        cooldown_field: cooldown_until,
                        "updated_at": time.time(),
                    }
                }
            result = await collection.update_one({"filename": filename}, update)
            if result.matched_count == 0:
                log.warning(f"Credential {filename} not found")
                return False
            log.debug(f"Set model cooldown: {filename}, model_key={model_key}, cooldown_until={cooldown_until}")
            return True
        except Exception as e:
            log.error(f"Error setting model cooldown for {filename}: {e}")
            return False