gcli2api / src /storage /mongodb_manager.py
a3216's picture
sync: github -> hf space
c50496f
"""
MongoDB 存储管理器
"""
import json
import os
import random
import time
from typing import Any, Dict, List, Optional
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from log import log
class MongoDBManager:
"""MongoDB 数据库管理器"""
# 状态字段常量
STATE_FIELDS = {
"error_codes",
"error_messages",
"disabled",
"last_success",
"user_email",
"model_cooldowns",
"preview",
"tier",
"enable_credit",
}
@staticmethod
def _escape_model_name(model_name: str) -> str:
"""
转义模型名中的点号,避免 MongoDB 将其解释为嵌套结构
Args:
model_name: 原始模型名 (如 "gemini-2.5-flash")
Returns:
转义后的模型名 (如 "gemini-2-5-flash")
"""
return model_name.replace(".", "-")
def __init__(self):
self._client: Optional[AsyncIOMotorClient] = None
self._db: Optional[AsyncIOMotorDatabase] = None
self._initialized = False
# 内存配置缓存 - 初始化时加载一次
self._config_cache: Dict[str, Any] = {}
self._config_loaded = False
# Redis 缓存(仅当 REDIS_URL 环境变量存在时启用)
self._redis = None
self._redis_enabled: bool = False
async def initialize(self) -> None:
"""初始化 MongoDB 连接"""
if self._initialized:
return
try:
mongodb_uri = os.getenv("MONGODB_URI")
if not mongodb_uri:
raise ValueError("MONGODB_URI environment variable not set")
database_name = os.getenv("MONGODB_DATABASE", "gcli2api")
self._client = AsyncIOMotorClient(mongodb_uri)
self._db = self._client[database_name]
# 测试连接
await self._db.command("ping")
# 创建索引
await self._create_indexes()
# 加载配置到内存
await self._load_config_cache()
self._initialized = True
log.info(f"MongoDB storage initialized (database: {database_name})")
# 尝试初始化 Redis(可选)
await self._init_redis()
except Exception as e:
log.error(f"Error initializing MongoDB: {e}")
raise
async def _create_indexes(self):
"""
创建索引
"""
from pymongo import IndexModel, ASCENDING
credentials_collection = self._db["credentials"]
antigravity_credentials_collection = self._db["antigravity_credentials"]
# ===== Geminicli 凭证索引 =====
geminicli_indexes = [
# 唯一索引 - 用于所有按文件名的精确查询
IndexModel([("filename", ASCENDING)], unique=True, name="idx_filename_unique"),
# 复合索引 - 用于 get_next_available_credential 和 get_available_credentials_list
# 查询模式: {disabled: False} + sort by rotation_order
IndexModel(
[("disabled", ASCENDING), ("rotation_order", ASCENDING)],
name="idx_disabled_rotation"
),
# 单字段索引 - 用于 get_credentials_summary 的错误筛选
IndexModel([("error_codes", ASCENDING)], name="idx_error_codes"),
# 单字段索引 - 用于 get_duplicate_credentials_by_email 的去重查询
IndexModel([("user_email", ASCENDING)], name="idx_user_email"),
]
# ===== Antigravity 凭证索引 =====
antigravity_indexes = [
# 唯一索引
IndexModel([("filename", ASCENDING)], unique=True, name="idx_filename_unique"),
# 复合索引 - 查询模式: {disabled: False} + sort by rotation_order
# 查询模式: {disabled: False} + 可选 sort by rotation_order
IndexModel(
[("disabled", ASCENDING), ("rotation_order", ASCENDING)],
name="idx_disabled_rotation"
),
# 单字段索引 - 错误筛选
IndexModel([("error_codes", ASCENDING)], name="idx_error_codes"),
# 单字段索引 - 去重查询
IndexModel([("user_email", ASCENDING)], name="idx_user_email"),
]
# 并行创建新索引
try:
await credentials_collection.create_indexes(geminicli_indexes)
await antigravity_credentials_collection.create_indexes(antigravity_indexes)
log.debug("MongoDB indexes created successfully")
except Exception as e:
# 如果索引已存在,忽略错误
if "already exists" not in str(e).lower():
log.warning(f"Index creation warning: {e}")
async def _load_config_cache(self):
"""加载配置到内存缓存(仅在初始化时调用一次)"""
if self._config_loaded:
return
try:
config_collection = self._db["config"]
cursor = config_collection.find({})
async for doc in cursor:
self._config_cache[doc["key"]] = doc.get("value")
self._config_loaded = True
log.debug(f"Loaded {len(self._config_cache)} config items into cache")
except Exception as e:
log.error(f"Error loading config cache: {e}")
self._config_cache = {}
# ============ Redis 缓存(可选,仅当 REDIS_URL 存在时启用)============
async def _init_redis(self) -> None:
"""初始化 Redis 连接并重建凭证池缓存(若 REDIS_URL 存在)"""
redis_url = os.getenv("REDIS_URL")
if not redis_url:
return
try:
import redis.asyncio as aioredis # type: ignore
except ImportError:
log.warning("redis package not installed, Redis cache disabled. Run: pip install redis")
return
try:
self._redis = aioredis.from_url(redis_url, decode_responses=True)
await self._redis.ping()
self._redis_enabled = True
log.info("Redis connected, rebuilding credential pool cache...")
# 并行重建两个 mode 的缓存及配置缓存
import asyncio
await asyncio.gather(
self._rebuild_redis_cache("geminicli"),
self._rebuild_redis_cache("antigravity"),
self._load_config_to_redis(),
)
log.info("Redis credential pool cache ready")
except Exception as e:
log.warning(f"Redis init failed, falling back to MongoDB-only mode: {e}")
self._redis = None
self._redis_enabled = False
# ---- Redis key 工具 ----
def _rk_avail(self, mode: str) -> str:
"""所有未禁用凭证的 Redis Set key"""
return f"gcli:avail:{mode}"
def _rk_tier(self, mode: str, tier: str) -> str:
"""按 tier 分桶的未禁用凭证 Redis Set key"""
return f"gcli:tier:{mode}:{tier}"
def _rk_preview(self, mode: str) -> str:
"""preview=True 凭证的 Redis Set key"""
return f"gcli:preview:{mode}"
def _rk_cd(self, mode: str, filename: str, escaped_model: str) -> str:
"""模型冷却 Redis key(带 TTL)"""
return f"gcli:cd:{mode}:{filename}:{escaped_model}"
# ---- Redis 缓存维护 ----
async def _rebuild_redis_cache(self, mode: str) -> None:
"""
从 MongoDB 重建指定 mode 的 Redis 凭证池缓存。
使用临时 key + RENAME 原子替换
"""
if not self._redis:
return
try:
collection = self._db[self._get_collection_name(mode)]
# 同时投影 model_cooldowns、tier、preview,以便重建缓存
projection: Dict[str, Any] = {"filename": 1, "disabled": 1, "model_cooldowns": 1, "tier": 1, "preview": 1, "_id": 0}
avail: List[str] = []
tier_buckets: Dict[str, List[str]] = {} # tier -> [filename, ...]
preview_members: List[str] = []
cooldown_entries: List[tuple] = [] # (cd_key, ttl_seconds, value)
current_time = time.time()
async for doc in collection.find({}, projection=projection):
if not doc.get("disabled", False):
filename = doc["filename"]
avail.append(filename)
# 按 tier 分桶
tier = doc.get("tier") or "pro"
tier_buckets.setdefault(tier, []).append(filename)
# preview 分桶(仅 geminicli)
if mode == "geminicli" and doc.get("preview", True):
preview_members.append(filename)
# 收集未过期的模型冷却,重建 Redis TTL Key
model_cooldowns = doc.get("model_cooldowns") or {}
for escaped_model, cooldown_until in model_cooldowns.items():
if isinstance(cooldown_until, (int, float)) and cooldown_until > current_time:
ttl = int(cooldown_until - current_time)
if ttl > 0:
cd_key = self._rk_cd(mode, filename, escaped_model)
cooldown_entries.append((cd_key, ttl, str(cooldown_until)))
tmp_avail = self._rk_avail(mode) + ":tmp"
pipe = self._redis.pipeline()
# 先写临时 key(此时正式 key 仍完整可用)
pipe.delete(tmp_avail)
if avail:
pipe.sadd(tmp_avail, *avail)
await pipe.execute()
# RENAME 是原子操作:瞬间切换,不存在空窗
pipe2 = self._redis.pipeline()
if avail:
pipe2.rename(tmp_avail, self._rk_avail(mode))
else:
pipe2.delete(self._rk_avail(mode))
pipe2.delete(tmp_avail)
await pipe2.execute()
# 重建 tier 分桶 Set(原子替换)
all_tiers = ("free", "pro", "ultra")
pipe3 = self._redis.pipeline()
for tier in all_tiers:
tier_key = self._rk_tier(mode, tier)
tmp_tier_key = tier_key + ":tmp"
pipe3.delete(tmp_tier_key)
members = tier_buckets.get(tier, [])
if members:
pipe3.sadd(tmp_tier_key, *members)
await pipe3.execute()
pipe4 = self._redis.pipeline()
for tier in all_tiers:
tier_key = self._rk_tier(mode, tier)
tmp_tier_key = tier_key + ":tmp"
members = tier_buckets.get(tier, [])
if members:
pipe4.rename(tmp_tier_key, tier_key)
else:
pipe4.delete(tier_key)
pipe4.delete(tmp_tier_key)
await pipe4.execute()
# 重建 preview 分桶(仅 geminicli)
preview_key = self._rk_preview(mode)
tmp_preview_key = preview_key + ":tmp"
pipe5 = self._redis.pipeline()
pipe5.delete(tmp_preview_key)
if preview_members:
pipe5.sadd(tmp_preview_key, *preview_members)
await pipe5.execute()
pipe6 = self._redis.pipeline()
if preview_members:
pipe6.rename(tmp_preview_key, preview_key)
else:
pipe6.delete(preview_key)
pipe6.delete(tmp_preview_key)
await pipe6.execute()
# 批量恢复未过期的模型冷却 TTL Key
if cooldown_entries:
pipe7 = self._redis.pipeline()
for cd_key, ttl, value in cooldown_entries:
pipe7.setex(cd_key, ttl, value)
await pipe7.execute()
log.debug(
f"Redis cache rebuilt [{mode}]: {len(avail)} avail, "
f"tiers={{{', '.join(f'{t}:{len(tier_buckets.get(t, []))}' for t in all_tiers)}}}, "
f"preview={len(preview_members)}, "
f"{len(cooldown_entries)} cooldown key(s) restored"
)
except Exception as e:
log.warning(f"Redis rebuild cache error [{mode}]: {e}")
async def _redis_add_cred(self, mode: str, filename: str, tier: str = "pro", preview: bool = True) -> None:
"""将凭证加入 Redis 可用池及对应 tier 分桶、preview 分桶"""
if not self._redis_enabled:
return
try:
pipe = self._redis.pipeline()
pipe.sadd(self._rk_avail(mode), filename)
pipe.sadd(self._rk_tier(mode, tier), filename)
if mode == "geminicli" and preview:
pipe.sadd(self._rk_preview(mode), filename)
await pipe.execute()
except Exception as e:
log.warning(f"Redis add_cred error: {e}")
async def _redis_remove_cred(self, mode: str, filename: str, tier: Optional[str] = None) -> None:
"""从 Redis 所有池中移除凭证"""
if not self._redis_enabled:
return
try:
pipe = self._redis.pipeline()
pipe.srem(self._rk_avail(mode), filename)
if tier:
pipe.srem(self._rk_tier(mode, tier), filename)
else:
# tier 未知时从所有分桶中移除
for t in ("free", "pro", "ultra"):
pipe.srem(self._rk_tier(mode, t), filename)
pipe.srem(self._rk_preview(mode), filename)
await pipe.execute()
except Exception as e:
log.warning(f"Redis remove_cred error: {e}")
async def _redis_sync_cred(self, mode: str, filename: str, disabled: bool, tier: str = "pro", preview: bool = True) -> None:
"""根据最新状态同步单个凭证在 Redis 中的集合成员"""
if not self._redis_enabled:
return
try:
pipe = self._redis.pipeline()
if disabled:
pipe.srem(self._rk_avail(mode), filename)
for t in ("free", "pro", "ultra"):
pipe.srem(self._rk_tier(mode, t), filename)
pipe.srem(self._rk_preview(mode), filename)
else:
pipe.sadd(self._rk_avail(mode), filename)
pipe.sadd(self._rk_tier(mode, tier), filename)
if mode == "geminicli" and preview:
pipe.sadd(self._rk_preview(mode), filename)
else:
pipe.srem(self._rk_preview(mode), filename)
await pipe.execute()
except Exception as e:
log.warning(f"Redis sync_cred error: {e}")
async def _get_next_available_from_redis(
self, mode: str, model_name: Optional[str], exclude_free_tier: bool = False, preview_only: bool = False
) -> Optional[tuple]:
"""
Redis 快速路径:随机取候选凭证,跳过冷却中的,返回 (filename, credential_data)。
失败或池为空时返回 None,由调用方降级到 MongoDB。
"""
try:
# 选择候选池优先级:preview_only > exclude_free_tier > 全量池
if preview_only and exclude_free_tier:
# preview 且非 free:preview ∩ (pro ∪ ultra)
preview_set = await self._redis.smembers(self._rk_preview(mode))
pro_members = await self._redis.smembers(self._rk_tier(mode, "pro"))
ultra_members = await self._redis.smembers(self._rk_tier(mode, "ultra"))
non_free = pro_members | ultra_members
all_candidates = list(preview_set & non_free)
if not all_candidates:
log.debug(f"[Redis MISS] mode={mode} preview+non-free: no candidates, fallback to MongoDB")
return None
sample_size = min(len(all_candidates), 10)
candidates = random.sample(all_candidates, sample_size)
elif preview_only:
preview_key = self._rk_preview(mode)
preview_size = await self._redis.scard(preview_key)
if preview_size == 0:
log.debug(f"[Redis MISS] mode={mode} preview_only: pool empty, fallback to MongoDB")
return None
sample_size = min(preview_size, 10)
candidates = await self._redis.srandmember(preview_key, sample_size)
if not candidates:
return None
elif exclude_free_tier:
pro_members = await self._redis.smembers(self._rk_tier(mode, "pro"))
ultra_members = await self._redis.smembers(self._rk_tier(mode, "ultra"))
all_candidates = list(pro_members | ultra_members)
if not all_candidates:
log.debug(f"[Redis MISS] mode={mode} exclude_free: no non-free creds, fallback to MongoDB")
return None
sample_size = min(len(all_candidates), 10)
candidates = random.sample(all_candidates, sample_size)
else:
pool_key = self._rk_avail(mode)
pool_size = await self._redis.scard(pool_key)
if pool_size == 0:
log.debug(f"[Redis MISS] mode={mode} pool_key={pool_key}: pool empty, fallback to MongoDB")
return None
sample_size = min(pool_size, 10)
candidates = await self._redis.srandmember(pool_key, sample_size)
if not candidates:
return None
# 过滤冷却中的凭证
if model_name:
escaped = self._escape_model_name(model_name)
for filename in candidates:
cd_key = self._rk_cd(mode, filename, escaped)
if not await self._redis.exists(cd_key):
credential_data = await self.get_credential(filename, mode)
if mode == "antigravity":
state = await self.get_credential_state(filename, mode)
credential_data = credential_data or {}
credential_data["enable_credit"] = bool(state.get("enable_credit", False))
log.debug(f"[Redis HIT] mode={mode} model={model_name} -> {filename}")
return filename, credential_data
# 所有候选都在冷却中,降级到 MongoDB
log.debug(f"[Redis MISS] mode={mode} model={model_name}: all {len(candidates)} candidates in cooldown, fallback to MongoDB")
return None
else:
filename = candidates[0]
credential_data = await self.get_credential(filename, mode)
if mode == "antigravity":
state = await self.get_credential_state(filename, mode)
credential_data = credential_data or {}
credential_data["enable_credit"] = bool(state.get("enable_credit", False))
log.debug(f"[Redis HIT] mode={mode} -> {filename}")
return filename, credential_data
except Exception as e:
log.warning(f"Redis get_next_available error: {e}")
return None
async def close(self) -> None:
"""关闭 MongoDB 连接"""
if self._redis:
await self._redis.aclose()
self._redis = None
self._redis_enabled = False
if self._client:
self._client.close()
self._client = None
self._db = None
self._initialized = False
log.debug("MongoDB storage closed")
def _ensure_initialized(self):
"""确保已初始化"""
if not self._initialized:
raise RuntimeError("MongoDB manager not initialized")
def _get_collection_name(self, mode: str) -> str:
"""根据 mode 获取对应的集合名"""
if mode == "antigravity":
return "antigravity_credentials"
elif mode == "geminicli":
return "credentials"
else:
raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'")
# ============ SQL 方法 ============
async def get_next_available_credential(
self, mode: str = "geminicli", model_name: Optional[str] = None
) -> Optional[tuple[str, Dict[str, Any]]]:
"""
随机获取一个可用凭证(负载均衡)
- 未禁用
- 如果提供了 model_name,还会检查模型级冷却
- 随机选择
Args:
mode: 凭证模式 ("geminicli" 或 "antigravity")
model_name: 完整模型名(如 "gemini-2.0-flash-exp")
Note:
- 开启 Redis 时:利用 Redis Set 随机选凭证 + TTL key 判断冷却
- 未开启 Redis 时:使用 count + random skip + limit(1)
"""
self._ensure_initialized()
# Redis 快速路径:根据模型名派生过滤标志,直接在 Redis 分桶中筛选
if self._redis_enabled:
model_lower = model_name.lower() if model_name else ""
exclude_free = mode == "geminicli" and "pro" in model_lower
preview_only = mode == "geminicli" and "preview" in model_lower
result = await self._get_next_available_from_redis(
mode, model_name, exclude_free_tier=exclude_free, preview_only=preview_only
)
if result is not None:
return result
# result 为 None:池为空或所有候选都冷却中,降级到 MongoDB 以扩大样本空间
log.debug(f"[MongoDB fallback] mode={mode} model={model_name}")
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
current_time = time.time()
# 构建普通查询(避免 $sample 聚合导致全集合扫描)
match_query: Dict[str, Any] = {"disabled": False}
# pro 模型只允许非 free tier 凭证
if mode == "geminicli" and model_name and "pro" in model_name.lower():
match_query["tier"] = {"$ne": "free"}
# preview 模型只允许 preview=True 的凭证
if mode == "geminicli" and model_name and "preview" in model_name.lower():
match_query["preview"] = True
# 冷却检查:直接用 MongoDB 查询表达,无需 $addFields
if model_name:
escaped_model_name = self._escape_model_name(model_name)
field = f"model_cooldowns.{escaped_model_name}"
match_query["$or"] = [
{field: {"$exists": False}},
{field: {"$lte": current_time}},
]
# 统计符合条件的凭证总数(走索引,极快)
count = await collection.count_documents(match_query)
if count == 0:
return None
# 随机偏移 + limit(1),替代 $sample,避免全集合随机排序
skip_n = random.randint(0, count - 1)
projection = {"filename": 1, "credential_data": 1, "enable_credit": 1, "_id": 0}
docs = await collection.find(match_query, projection).skip(skip_n).limit(1).to_list(1)
if docs:
doc = docs[0]
credential_data = doc.get("credential_data") or {}
if mode == "antigravity":
credential_data["enable_credit"] = bool(doc.get("enable_credit", False))
return doc["filename"], credential_data
return None
except Exception as e:
log.error(f"Error getting next available credential (mode={mode}, model_name={model_name}): {e}")
return None
async def get_available_credentials_list(self, mode: str = "geminicli") -> List[str]:
"""
获取所有可用凭证列表
- 未禁用
- 按轮换顺序排序
"""
self._ensure_initialized()
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
pipeline = [
{"$match": {"disabled": False}},
{"$sort": {"rotation_order": 1}},
{"$project": {"filename": 1, "_id": 0}}
]
docs = await collection.aggregate(pipeline).to_list(length=None)
return [doc["filename"] for doc in docs]
except Exception as e:
log.error(f"Error getting available credentials list (mode={mode}): {e}")
return []
# ============ StorageBackend 协议方法 ============
async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool:
"""存储或更新凭证"""
self._ensure_initialized()
# 统一使用 basename 处理文件名
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
current_ts = time.time()
# 使用 upsert + $setOnInsert
# 如果文档存在,只更新 credential_data 和 updated_at
# 如果文档不存在,设置所有默认字段
# 先尝试更新现有文档
result = await collection.update_one(
{"filename": filename},
{
"$set": {
"credential_data": credential_data,
"updated_at": current_ts,
}
}
)
# 如果没有匹配到(新凭证),需要插入
if result.matched_count == 0:
# 获取下一个 rotation_order
pipeline = [
{"$group": {"_id": None, "max_order": {"$max": "$rotation_order"}}},
{"$project": {"_id": 0, "next_order": {"$add": ["$max_order", 1]}}}
]
result_list = await collection.aggregate(pipeline).to_list(length=1)
next_order = result_list[0]["next_order"] if result_list else 0
# 插入新凭证(使用 insert_one,因为我们已经确认不存在)
try:
new_credential = {
"filename": filename,
"credential_data": credential_data,
"disabled": False,
"error_codes": [],
"error_messages": [],
"last_success": current_ts,
"user_email": None,
"model_cooldowns": {},
"preview": True,
"tier": "pro",
"rotation_order": next_order,
"call_count": 0,
"created_at": current_ts,
"updated_at": current_ts,
}
if mode == "antigravity":
new_credential["enable_credit"] = False
await collection.insert_one(new_credential)
# 新凭证插入成功,添加到 Redis 可用池
await self._redis_add_cred(mode, filename)
except Exception as insert_error:
# 处理并发插入导致的重复键错误
if "duplicate key" in str(insert_error).lower():
# 重试更新(已存在的凭证,无需更新 Redis)
await collection.update_one(
{"filename": filename},
{"$set": {"credential_data": credential_data, "updated_at": current_ts}}
)
else:
raise
log.debug(f"Stored credential: {filename} (mode={mode})")
return True
except Exception as e:
log.error(f"Error storing credential {filename}: {e}")
return False
async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]:
"""获取凭证数据"""
self._ensure_initialized()
# 统一使用 basename 处理文件名
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 精确匹配,只投影需要的字段
doc = await collection.find_one(
{"filename": filename},
{"credential_data": 1, "_id": 0}
)
if doc:
return doc.get("credential_data")
return None
except Exception as e:
log.error(f"Error getting credential {filename}: {e}")
return None
async def list_credentials(self, mode: str = "geminicli") -> List[str]:
"""列出所有凭证文件名"""
self._ensure_initialized()
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 使用聚合管道
pipeline = [
{"$sort": {"rotation_order": 1}},
{"$project": {"filename": 1, "_id": 0}}
]
docs = await collection.aggregate(pipeline).to_list(length=None)
return [doc["filename"] for doc in docs]
except Exception as e:
log.error(f"Error listing credentials: {e}")
return []
async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool:
"""删除凭证"""
self._ensure_initialized()
# 统一使用 basename 处理文件名
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 精确匹配删除
result = await collection.delete_one({"filename": filename})
deleted_count = result.deleted_count
if deleted_count > 0:
# 从 Redis 池中移除
await self._redis_remove_cred(mode, filename)
log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})")
return True
else:
log.warning(f"No credential found to delete: {filename} (mode={mode})")
return False
except Exception as e:
log.error(f"Error deleting credential {filename}: {e}")
return False
async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]:
"""
获取按邮箱分组的重复凭证信息(只查询邮箱和文件名,不加载完整凭证数据)
用于去重操作
Args:
mode: 凭证模式 ("geminicli" 或 "antigravity")
Returns:
包含 email_groups(邮箱分组)、duplicate_count(重复数量)、no_email_count(无邮箱数量)的字典
"""
self._ensure_initialized()
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 使用聚合管道,只查询 filename 和 user_email 字段
pipeline = [
{
"$project": {
"filename": 1,
"user_email": 1,
"_id": 0
}
},
{
"$sort": {"filename": 1}
}
]
docs = await collection.aggregate(pipeline).to_list(length=None)
# 按邮箱分组
email_to_files = {}
no_email_files = []
for doc in docs:
filename = doc.get("filename")
user_email = doc.get("user_email")
if user_email:
if user_email not in email_to_files:
email_to_files[user_email] = []
email_to_files[user_email].append(filename)
else:
no_email_files.append(filename)
# 找出重复的邮箱组
duplicate_groups = []
total_duplicate_count = 0
for email, files in email_to_files.items():
if len(files) > 1:
# 保留第一个文件,其他为重复
duplicate_groups.append({
"email": email,
"kept_file": files[0],
"duplicate_files": files[1:],
"duplicate_count": len(files) - 1,
})
total_duplicate_count += len(files) - 1
return {
"email_groups": email_to_files,
"duplicate_groups": duplicate_groups,
"duplicate_count": total_duplicate_count,
"no_email_files": no_email_files,
"no_email_count": len(no_email_files),
"unique_email_count": len(email_to_files),
"total_count": len(docs),
}
except Exception as e:
log.error(f"Error getting duplicate credentials by email: {e}")
return {
"email_groups": {},
"duplicate_groups": [],
"duplicate_count": 0,
"no_email_files": [],
"no_email_count": 0,
"unique_email_count": 0,
"total_count": 0,
}
async def update_credential_state(
self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli"
) -> bool:
"""更新凭证状态"""
self._ensure_initialized()
# 统一使用 basename 处理文件名
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 过滤只更新状态字段
valid_updates = {
k: v for k, v in state_updates.items() if k in self.STATE_FIELDS
}
if mode != "antigravity":
valid_updates.pop("enable_credit", None)
if not valid_updates:
return True
valid_updates["updated_at"] = time.time()
# 精确匹配更新
result = await collection.update_one(
{"filename": filename}, {"$set": valid_updates}
)
updated_count = result.modified_count + result.matched_count
# 如果 disabled 发生变化,同步 Redis 池成员关系
if self._redis_enabled and "disabled" in valid_updates:
if valid_updates["disabled"]:
# 直接禁用:从集合中移除
await self._redis_remove_cred(mode, filename)
else:
# 重新启用:需要读取当前 tier/preview 以正确放入分桶
doc = await collection.find_one(
{"filename": filename},
projection={"tier": 1, "preview": 1, "_id": 0},
)
tier_val = (doc or {}).get("tier", "pro") or "pro"
preview_val = (doc or {}).get("preview", True)
await self._redis_sync_cred(mode, filename, disabled=False, tier=tier_val, preview=preview_val)
elif self._redis_enabled and ("tier" in valid_updates or "preview" in valid_updates):
# tier 或 preview 更新:重新同步分桶(只在凭证未禁用时)
doc = await collection.find_one(
{"filename": filename},
projection={"disabled": 1, "tier": 1, "preview": 1, "_id": 0},
)
if doc and not doc.get("disabled", False):
tier_val = doc.get("tier", "pro") or "pro"
preview_val = doc.get("preview", True)
await self._redis_sync_cred(mode, filename, disabled=False, tier=tier_val, preview=preview_val)
return updated_count > 0
except Exception as e:
log.error(f"Error updating credential state {filename}: {e}")
return False
async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]:
"""获取凭证状态(不包含error_messages)"""
self._ensure_initialized()
# 统一使用 basename 处理文件名
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
current_time = time.time()
# 精确匹配
doc = await collection.find_one({"filename": filename})
if doc:
model_cooldowns = doc.get("model_cooldowns", {})
# 过滤掉损坏的数据(dict类型)和过期的冷却
if model_cooldowns:
model_cooldowns = {
k: v for k, v in model_cooldowns.items()
if isinstance(v, (int, float)) and v > current_time
}
state = {
"disabled": doc.get("disabled", False),
"error_codes": doc.get("error_codes", []),
"last_success": doc.get("last_success", current_time),
"user_email": doc.get("user_email"),
"model_cooldowns": model_cooldowns,
"preview": doc.get("preview", True),
"tier": doc.get("tier", "pro"),
}
if mode == "antigravity":
state["enable_credit"] = doc.get("enable_credit", False)
return state
# 返回默认状态
default_state = {
"disabled": False,
"error_codes": [],
"last_success": current_time,
"user_email": None,
"model_cooldowns": {},
"preview": True,
"tier": "pro",
}
if mode == "antigravity":
default_state["enable_credit"] = False
return default_state
except Exception as e:
log.error(f"Error getting credential state {filename}: {e}")
return {}
async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]:
"""获取所有凭证状态(不包含error_messages)"""
self._ensure_initialized()
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 使用投影只获取需要的字段(不包含error_messages)
projection = {
"filename": 1,
"disabled": 1,
"error_codes": 1,
"last_success": 1,
"user_email": 1,
"model_cooldowns": 1,
"preview": 1,
"tier": 1,
"enable_credit": 1,
"_id": 0
}
cursor = collection.find({}, projection=projection)
states = {}
current_time = time.time()
async for doc in cursor:
filename = doc["filename"]
model_cooldowns = doc.get("model_cooldowns", {})
# 自动过滤掉已过期的模型CD
if model_cooldowns:
model_cooldowns = {
k: v for k, v in model_cooldowns.items()
if isinstance(v, (int, float)) and v > current_time
}
state = {
"disabled": doc.get("disabled", False),
"error_codes": doc.get("error_codes", []),
"last_success": doc.get("last_success", time.time()),
"user_email": doc.get("user_email"),
"model_cooldowns": model_cooldowns,
"preview": doc.get("preview", True),
"tier": doc.get("tier", "pro"),
}
if mode == "antigravity":
state["enable_credit"] = doc.get("enable_credit", False)
states[filename] = state
return states
except Exception as e:
log.error(f"Error getting all credential states: {e}")
return {}
async def get_credentials_summary(
self,
offset: int = 0,
limit: Optional[int] = None,
status_filter: str = "all",
mode: str = "geminicli",
error_code_filter: Optional[str] = None,
cooldown_filter: Optional[str] = None,
preview_filter: Optional[str] = None,
tier_filter: Optional[str] = None
) -> Dict[str, Any]:
"""
获取凭证的摘要信息(不包含完整凭证数据)- 支持分页和状态筛选
Args:
offset: 跳过的记录数(默认0)
limit: 返回的最大记录数(None表示返回所有)
status_filter: 状态筛选(all=全部, enabled=仅启用, disabled=仅禁用)
mode: 凭证模式 ("geminicli" 或 "antigravity")
error_code_filter: 错误码筛选(格式如"400"或"403",筛选包含该错误码的凭证)
cooldown_filter: 冷却状态筛选("in_cooldown"=冷却中, "no_cooldown"=未冷却)
preview_filter: Preview筛选("preview"=支持preview, "no_preview"=不支持preview,仅geminicli模式有效)
tier_filter: tier筛选("free", "pro", "ultra")
Returns:
包含 items(凭证列表)、total(总数)、offset、limit 的字典
"""
self._ensure_initialized()
try:
# 根据 mode 选择集合名
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 构建查询条件
query = {}
if status_filter == "enabled":
query["disabled"] = False
elif status_filter == "disabled":
query["disabled"] = True
# 错误码筛选 - 兼容存储为数字或字符串的情况
if error_code_filter and str(error_code_filter).strip().lower() != "all":
filter_value = str(error_code_filter).strip()
query_values = [filter_value]
try:
query_values.append(int(filter_value))
except ValueError:
pass
query["error_codes"] = {"$in": query_values}
# 计算全局统计数据(不受筛选条件影响)
global_stats = {"total": 0, "normal": 0, "disabled": 0}
stats_pipeline = [
{
"$group": {
"_id": "$disabled",
"count": {"$sum": 1}
}
}
]
stats_result = await collection.aggregate(stats_pipeline).to_list(length=10)
for item in stats_result:
count = item["count"]
global_stats["total"] += count
if item["_id"]:
global_stats["disabled"] = count
else:
global_stats["normal"] = count
# 获取所有匹配的文档(用于冷却筛选,因为需要在Python中判断)
projection = {
"filename": 1,
"disabled": 1,
"error_codes": 1,
"last_success": 1,
"user_email": 1,
"rotation_order": 1,
"model_cooldowns": 1,
"preview": 1,
"tier": 1,
"enable_credit": 1,
"_id": 0
}
cursor = collection.find(query, projection=projection).sort("rotation_order", 1)
all_summaries = []
current_time = time.time()
async for doc in cursor:
model_cooldowns = doc.get("model_cooldowns", {})
# 自动过滤掉已过期的模型CD
active_cooldowns = {}
if model_cooldowns:
active_cooldowns = {
k: v for k, v in model_cooldowns.items()
if isinstance(v, (int, float)) and v > current_time
}
summary = {
"filename": doc["filename"],
"disabled": doc.get("disabled", False),
"error_codes": doc.get("error_codes", []),
"last_success": doc.get("last_success", current_time),
"user_email": doc.get("user_email"),
"rotation_order": doc.get("rotation_order", 0),
"model_cooldowns": active_cooldowns,
"preview": doc.get("preview", True),
"tier": doc.get("tier", "pro"),
}
if mode == "antigravity":
summary["enable_credit"] = bool(doc.get("enable_credit", False))
if mode == "geminicli" and preview_filter:
preview_value = summary.get("preview", True)
if preview_filter == "preview" and not preview_value:
continue
if preview_filter == "no_preview" and preview_value:
continue
# 应用tier筛选
if tier_filter and tier_filter in ("free", "pro", "ultra"):
if summary["tier"] != tier_filter:
continue
# 应用冷却筛选
if cooldown_filter == "in_cooldown":
# 只保留有冷却的凭证
if active_cooldowns:
all_summaries.append(summary)
elif cooldown_filter == "no_cooldown":
# 只保留没有冷却的凭证
if not active_cooldowns:
all_summaries.append(summary)
else:
# 不筛选冷却状态
all_summaries.append(summary)
# 应用分页
total_count = len(all_summaries)
if limit is not None:
summaries = all_summaries[offset:offset + limit]
else:
summaries = all_summaries[offset:]
return {
"items": summaries,
"total": total_count,
"offset": offset,
"limit": limit,
"stats": global_stats,
}
except Exception as e:
log.error(f"Error getting credentials summary: {e}")
return {
"items": [],
"total": 0,
"offset": offset,
"limit": limit,
"stats": {"total": 0, "normal": 0, "disabled": 0},
}
# ============ 配置管理(内存缓存 + 可选 Redis)============
def _rk_config(self, key: str) -> str:
"""配置项的 Redis key"""
return f"gcli:config:{key}"
def _rk_config_all(self) -> str:
"""所有配置的 Redis Hash key"""
return "gcli:config"
async def _load_config_to_redis(self) -> None:
"""将所有配置从 MongoDB 同步到 Redis Hash"""
if not self._redis_enabled:
return
try:
config_collection = self._db["config"]
cursor = config_collection.find({})
mapping = {}
async for doc in cursor:
mapping[doc["key"]] = json.dumps(doc.get("value"))
pipe = self._redis.pipeline()
pipe.delete(self._rk_config_all())
if mapping:
pipe.hset(self._rk_config_all(), mapping=mapping)
await pipe.execute()
log.debug(f"Synced {len(mapping)} config items to Redis")
except Exception as e:
log.warning(f"Failed to sync config to Redis: {e}")
async def set_config(self, key: str, value: Any) -> bool:
"""设置配置(写入数据库;Redis 启用时写 Redis,否则更新内存缓存)"""
self._ensure_initialized()
try:
config_collection = self._db["config"]
await config_collection.update_one(
{"key": key},
{"$set": {"value": value, "updated_at": time.time()}},
upsert=True,
)
if self._redis_enabled:
try:
await self._redis.hset(self._rk_config_all(), key, json.dumps(value))
except Exception as e:
log.warning(f"Redis config set error for key={key}: {e}")
else:
self._config_cache[key] = value
return True
except Exception as e:
log.error(f"Error setting config {key}: {e}")
return False
async def reload_config_cache(self):
"""重新加载配置缓存(在批量修改配置后调用)"""
self._ensure_initialized()
if self._redis_enabled:
await self._load_config_to_redis()
else:
self._config_loaded = False
await self._load_config_cache()
log.info("Config cache reloaded from database")
async def get_config(self, key: str, default: Any = None) -> Any:
"""获取配置(Redis 启用时从 Redis 读取,否则从内存缓存)"""
self._ensure_initialized()
if self._redis_enabled:
try:
raw = await self._redis.hget(self._rk_config_all(), key)
if raw is not None:
return json.loads(raw)
return default
except Exception as e:
log.warning(f"Redis config get error for key={key}: {e}")
return default
return self._config_cache.get(key, default)
async def get_all_config(self) -> Dict[str, Any]:
"""获取所有配置(Redis 启用时从 Redis 读取,否则从内存缓存)"""
self._ensure_initialized()
if self._redis_enabled:
try:
raw_map = await self._redis.hgetall(self._rk_config_all())
return {k: json.loads(v) for k, v in raw_map.items()}
except Exception as e:
log.warning(f"Redis config getall error: {e}")
return {}
return self._config_cache.copy()
async def delete_config(self, key: str) -> bool:
"""删除配置"""
self._ensure_initialized()
try:
config_collection = self._db["config"]
result = await config_collection.delete_one({"key": key})
if self._redis_enabled:
try:
await self._redis.hdel(self._rk_config_all(), key)
except Exception as e:
log.warning(f"Redis config delete error for key={key}: {e}")
else:
self._config_cache.pop(key, None)
return result.deleted_count > 0
except Exception as e:
log.error(f"Error deleting config {key}: {e}")
return False
async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]:
"""
专门获取凭证的错误信息(包含 error_codes 和 error_messages)
Args:
filename: 凭证文件名
mode: 凭证模式 ("geminicli" 或 "antigravity")
Returns:
包含 error_codes 和 error_messages 的字典
"""
self._ensure_initialized()
# 统一使用 basename 处理文件名
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 精确匹配
doc = await collection.find_one(
{"filename": filename},
{"error_codes": 1, "error_messages": 1, "_id": 0}
)
if doc:
return {
"filename": filename,
"error_codes": doc.get("error_codes", []),
"error_messages": doc.get("error_messages", []),
}
# 凭证不存在,返回空错误信息
return {
"filename": filename,
"error_codes": [],
"error_messages": [],
}
except Exception as e:
log.error(f"Error getting credential errors {filename}: {e}")
return {
"filename": filename,
"error_codes": [],
"error_messages": [],
"error": str(e)
}
# ============ 模型级冷却管理 ============
async def set_model_cooldown(
self,
filename: str,
model_name: str,
cooldown_until: Optional[float],
mode: str = "geminicli"
) -> bool:
"""
设置特定模型的冷却时间
Args:
filename: 凭证文件名
model_name: 模型名(完整模型名,如 "gemini-2.0-flash-exp")
cooldown_until: 冷却截止时间戳(None 表示清除冷却)
mode: 凭证模式 ("geminicli" 或 "antigravity")
Returns:
是否成功
"""
self._ensure_initialized()
# 统一使用 basename 处理文件名
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
# 转义模型名中的点号
escaped_model_name = self._escape_model_name(model_name)
# 使用原子操作直接更新,避免竞态条件
if cooldown_until is None:
# 删除指定模型的冷却
result = await collection.update_one(
{"filename": filename},
{
"$unset": {f"model_cooldowns.{escaped_model_name}": ""},
"$set": {"updated_at": time.time()}
}
)
else:
# 设置冷却时间
result = await collection.update_one(
{"filename": filename},
{
"$set": {
f"model_cooldowns.{escaped_model_name}": cooldown_until,
"updated_at": time.time()
}
}
)
if result.matched_count == 0:
log.warning(f"Credential {filename} not found")
return False
# 同步写入 Redis TTL key
if self._redis_enabled:
cd_key = self._rk_cd(mode, filename, escaped_model_name)
if cooldown_until is None:
await self._redis.delete(cd_key)
else:
ttl = int(cooldown_until - time.time())
if ttl > 0:
await self._redis.setex(cd_key, ttl, str(cooldown_until))
else:
# 冷却已经过期,确保清除
await self._redis.delete(cd_key)
log.debug(f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}")
return True
except Exception as e:
log.error(f"Error setting model cooldown for {filename}: {e}")
return False
async def clear_all_model_cooldowns(
self,
filename: str,
mode: str = "geminicli"
) -> bool:
"""
清除某个凭证的所有模型冷却时间
Args:
filename: 凭证文件名
mode: 凭证模式 ("geminicli" 或 "antigravity")
Returns:
是否成功
"""
self._ensure_initialized()
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
doc = await collection.find_one(
{"filename": filename},
{"model_cooldowns": 1, "_id": 0}
)
if not doc:
log.warning(f"Credential {filename} not found")
return False
model_cooldowns = doc.get("model_cooldowns") or {}
await collection.update_one(
{"filename": filename},
{
"$set": {
"model_cooldowns": {},
"updated_at": time.time(),
}
}
)
if self._redis_enabled and isinstance(model_cooldowns, dict) and model_cooldowns:
redis_keys = [self._rk_cd(mode, filename, escaped_model) for escaped_model in model_cooldowns.keys()]
await self._redis.delete(*redis_keys)
log.debug(f"Cleared all model cooldowns: {filename} (mode={mode})")
return True
except Exception as e:
log.error(f"Error clearing all model cooldowns for {filename}: {e}")
return False
async def record_success(
self,
filename: str,
model_name: Optional[str] = None,
mode: str = "geminicli"
) -> None:
"""
成功调用后的条件写入:
- 只有当前 error_codes 非空时才清除错误并写 last_success
- 只有当前存在该模型的冷却键时才清除
通过 MongoDB 服务端条件匹配实现
"""
self._ensure_initialized()
filename = os.path.basename(filename)
try:
collection_name = self._get_collection_name(mode)
collection = self._db[collection_name]
now = time.time()
# 条件写入:只有 error_codes 非空时才触发,避免无意义的写 IO
await collection.update_one(
{"filename": filename, "error_codes": {"$ne": []}},
{"$set": {
"last_success": now,
"error_codes": [],
"error_messages": {},
"updated_at": now,
}}
)
# 条件删除模型冷却:只有该键存在时才写入
if model_name:
escaped = self._escape_model_name(model_name)
await collection.update_one(
{"filename": filename, f"model_cooldowns.{escaped}": {"$exists": True}},
{"$unset": {f"model_cooldowns.{escaped}": ""}, "$set": {"updated_at": now}}
)
# 同步删除 Redis 冷却 key
if self._redis_enabled:
await self._redis.delete(self._rk_cd(mode, filename, escaped))
except Exception as e:
log.error(f"Error recording success for {filename}: {e}")