"""PostgreSQL storage manager."""

import asyncio
import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import asyncpg

from log import log


class PSQLManager:
    """PostgreSQL-backed storage for credentials, credential state and config.

    Two credential tables are managed, selected by ``mode``:
    ``"geminicli"`` -> ``credentials`` and ``"antigravity"`` ->
    ``antigravity_credentials``. Key/value configuration lives in the
    ``config`` table and is mirrored in an in-memory cache.

    JSON-valued columns (``error_codes``, ``error_messages``,
    ``model_cooldowns``, ``credential_data``) are stored as TEXT and
    (de)serialized with ``json``.
    """

    # Names of the state columns that update_credential_state() may write.
    STATE_FIELDS = {
        "error_codes",
        "error_messages",
        "disabled",
        "last_success",
        "user_email",
        "model_cooldowns",
        "preview",
        "tier",
        "enable_credit",
    }

    def __init__(self):
        """Create an uninitialized manager; call ``initialize()`` before use."""
        self._dsn: Optional[str] = None
        self._pool: Optional[asyncpg.Pool] = None
        self._initialized = False
        self._lock = asyncio.Lock()

        # In-memory mirror of the ``config`` table.
        self._config_cache: Dict[str, Any] = {}
        self._config_loaded = False

    async def initialize(self) -> None:
        """Open the connection pool, create/repair the schema and load config.

        The DSN is read from the ``POSTGRESQL_URI`` environment variable.
        Calling this again after a successful initialization is a no-op.

        Raises:
            RuntimeError: if ``POSTGRESQL_URI`` is unset or empty.
            Exception: any error raised while connecting or creating tables
                is logged and re-raised after the pool has been closed.
        """
        if self._initialized:
            return

        async with self._lock:
            # Re-check under the lock: another task may have finished
            # initializing while this one was waiting.
            if self._initialized:
                return

            try:
                self._dsn = os.getenv("POSTGRESQL_URI", "")
                if not self._dsn:
                    raise RuntimeError("POSTGRESQL_URI environment variable is not set")

                self._pool = await asyncpg.create_pool(self._dsn, min_size=2, max_size=10)

                async with self._pool.acquire() as conn:
                    await self._create_tables(conn)
                    await self._ensure_schema_compatibility(conn)

                await self._load_config_cache()

                self._initialized = True
                log.info("PostgreSQL storage initialized")
            except Exception as e:
                log.error(f"Error initializing PostgreSQL: {e}")
                # Do not leave a half-initialized pool behind.
                if self._pool:
                    await self._pool.close()
                    self._pool = None
                raise

    async def _create_tables(self, conn: asyncpg.Connection) -> None:
        """Create the tables and indexes if they do not exist yet."""
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS credentials (
                id SERIAL PRIMARY KEY,
                filename TEXT UNIQUE NOT NULL,
                credential_data TEXT NOT NULL,
                disabled INTEGER DEFAULT 0,
                error_codes TEXT DEFAULT '[]',
                error_messages TEXT DEFAULT '[]',
                last_success DOUBLE PRECISION,
                user_email TEXT,
                model_cooldowns TEXT DEFAULT '{}',
                preview INTEGER DEFAULT 1,
                tier TEXT DEFAULT 'pro',
                rotation_order INTEGER DEFAULT 0,
                call_count INTEGER DEFAULT 0,
                created_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW()),
                updated_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())
            )
        """)

        await conn.execute("""
            CREATE TABLE IF NOT EXISTS antigravity_credentials (
                id SERIAL PRIMARY KEY,
                filename TEXT UNIQUE NOT NULL,
                credential_data TEXT NOT NULL,
                disabled INTEGER DEFAULT 0,
                error_codes TEXT DEFAULT '[]',
                error_messages TEXT DEFAULT '[]',
                last_success DOUBLE PRECISION,
                user_email TEXT,
                model_cooldowns TEXT DEFAULT '{}',
                tier TEXT DEFAULT 'pro',
                enable_credit INTEGER DEFAULT 0,
                rotation_order INTEGER DEFAULT 0,
                call_count INTEGER DEFAULT 0,
                created_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW()),
                updated_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())
            )
        """)

        await conn.execute("""
            CREATE TABLE IF NOT EXISTS config (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL,
                updated_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())
            )
        """)

        # Indexes
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_disabled
            ON credentials(disabled)
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_rotation_order
            ON credentials(rotation_order)
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_ag_disabled
            ON antigravity_credentials(disabled)
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_ag_rotation_order
            ON antigravity_credentials(rotation_order)
        """)

        log.debug("PostgreSQL tables and indexes created")

    async def _ensure_schema_compatibility(self, conn: asyncpg.Connection) -> None:
        """Add any expected column that is missing from an existing table.

        Failures are logged and swallowed (per column and overall), so a
        schema problem here does not abort initialization.
        """
        required_columns = {
            "credentials": [
                ("disabled", "INTEGER DEFAULT 0"),
                ("error_codes", "TEXT DEFAULT '[]'"),
                ("error_messages", "TEXT DEFAULT '[]'"),
                ("last_success", "DOUBLE PRECISION"),
                ("user_email", "TEXT"),
                ("model_cooldowns", "TEXT DEFAULT '{}'"),
                ("preview", "INTEGER DEFAULT 1"),
                ("tier", "TEXT DEFAULT 'pro'"),
                ("rotation_order", "INTEGER DEFAULT 0"),
                ("call_count", "INTEGER DEFAULT 0"),
                ("created_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"),
                ("updated_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"),
            ],
            "antigravity_credentials": [
                ("disabled", "INTEGER DEFAULT 0"),
                ("error_codes", "TEXT DEFAULT '[]'"),
                ("error_messages", "TEXT DEFAULT '[]'"),
                ("last_success", "DOUBLE PRECISION"),
                ("user_email", "TEXT"),
                ("model_cooldowns", "TEXT DEFAULT '{}'"),
                ("tier", "TEXT DEFAULT 'pro'"),
                ("enable_credit", "INTEGER DEFAULT 0"),
                ("rotation_order", "INTEGER DEFAULT 0"),
                ("call_count", "INTEGER DEFAULT 0"),
                ("created_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"),
                ("updated_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"),
            ],
        }

        try:
            for table_name, columns in required_columns.items():
                rows = await conn.fetch("""
                    SELECT column_name FROM information_schema.columns
                    WHERE table_name = $1
                """, table_name)
                existing = {r["column_name"] for r in rows}

                for col_name, col_def in columns:
                    if col_name in existing:
                        continue
                    # Table/column names come from the literal dict above,
                    # never from user input, so interpolation is safe here.
                    try:
                        await conn.execute(
                            f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_def}"
                        )
                        log.info(f"Added missing column {table_name}.{col_name}")
                    except Exception as e:
                        log.error(f"Failed to add column {table_name}.{col_name}: {e}")
        except Exception as e:
            log.error(f"Error ensuring schema compatibility: {e}")

    async def _load_config_cache(self) -> None:
        """Load the ``config`` table into the in-memory cache.

        Does nothing when the cache is already marked as loaded. Values that
        are not valid JSON are cached as the raw string. On any error the
        cache is reset to empty and left marked as not loaded.
        """
        if self._config_loaded:
            return

        try:
            async with self._pool.acquire() as conn:
                rows = await conn.fetch("SELECT key, value FROM config")

            # Build a fresh dict and swap it in, so that keys removed from
            # the database disappear from the cache on reload.
            fresh_cache: Dict[str, Any] = {}
            for row in rows:
                try:
                    fresh_cache[row["key"]] = json.loads(row["value"])
                except json.JSONDecodeError:
                    fresh_cache[row["key"]] = row["value"]

            self._config_cache = fresh_cache
            self._config_loaded = True
            log.debug(f"Loaded {len(self._config_cache)} config items into cache")
        except Exception as e:
            log.error(f"Error loading config cache: {e}")
            self._config_cache = {}

    async def close(self) -> None:
        """Close the connection pool and mark the manager as uninitialized."""
        if self._pool:
            await self._pool.close()
            self._pool = None
        self._initialized = False
        # Force the next initialize() to reload config from the database
        # instead of trusting a cache from the previous session.
        self._config_loaded = False
        log.debug("PostgreSQL storage closed")

    def _ensure_initialized(self) -> None:
        """Raise ``RuntimeError`` unless ``initialize()`` has completed."""
        if not self._initialized or not self._pool:
            raise RuntimeError("PostgreSQL manager not initialized")

    def _get_table_name(self, mode: str) -> str:
        """Map ``mode`` to its credentials table name.

        Raises:
            ValueError: if ``mode`` is neither ``"geminicli"`` nor
                ``"antigravity"``.
        """
        if mode == "antigravity":
            return "antigravity_credentials"
        elif mode == "geminicli":
            return "credentials"
        else:
            raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'")

    # ============ Credential lookup ============

    async def get_next_available_credential(
        self, mode: str = "geminicli", model_name: Optional[str] = None
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Pick a random enabled credential (simple load balancing).

        Rows are fetched in random order and the first acceptable one wins.

        Args:
            mode: ``"geminicli"`` or ``"antigravity"``.
            model_name: when given, credentials whose cooldown for this model
                has not expired yet are skipped. In geminicli mode it also
                drives tier/preview selection: model names containing "pro"
                exclude ``tier = 'free'`` rows; names containing "preview"
                may only use preview credentials; other models prefer
                non-preview credentials and fall back to preview ones.

        Returns:
            ``(filename, credential_data)`` or ``None`` when nothing is
            available or an error occurred (errors are logged). In
            antigravity mode ``credential_data`` additionally carries an
            ``enable_credit`` boolean taken from the row.
        """
        self._ensure_initialized()

        try:
            table_name = self._get_table_name(mode)
            current_time = time.time()

            async with self._pool.acquire() as conn:
                if mode == "geminicli":
                    tier_clause = ""
                    if model_name and "pro" in model_name.lower():
                        tier_clause = "AND (tier IS NULL OR tier != 'free')"

                    rows = await conn.fetch(f"""
                        SELECT filename, credential_data, model_cooldowns, preview
                        FROM {table_name}
                        WHERE disabled = 0 {tier_clause}
                        ORDER BY RANDOM()
                    """)

                    if not model_name:
                        if rows:
                            return rows[0]["filename"], json.loads(rows[0]["credential_data"])
                        return None

                    is_preview_model = "preview" in model_name.lower()
                    non_preview_creds = []
                    preview_creds = []

                    for row in rows:
                        model_cooldowns = json.loads(row["model_cooldowns"] or "{}")
                        cd = model_cooldowns.get(model_name)
                        # Usable when no cooldown is recorded or it has expired.
                        if cd is None or current_time >= cd:
                            if row["preview"]:
                                preview_creds.append((row["filename"], row["credential_data"]))
                            else:
                                non_preview_creds.append((row["filename"], row["credential_data"]))

                    if is_preview_model:
                        if preview_creds:
                            return preview_creds[0][0], json.loads(preview_creds[0][1])
                    else:
                        if non_preview_creds:
                            return non_preview_creds[0][0], json.loads(non_preview_creds[0][1])
                        elif preview_creds:
                            return preview_creds[0][0], json.loads(preview_creds[0][1])

                    return None
                else:
                    rows = await conn.fetch(f"""
                        SELECT filename, credential_data, model_cooldowns, enable_credit
                        FROM {table_name}
                        WHERE disabled = 0
                        ORDER BY RANDOM()
                    """)

                    if not model_name:
                        if rows:
                            credential_data = json.loads(rows[0]["credential_data"])
                            credential_data["enable_credit"] = bool(rows[0]["enable_credit"])
                            return rows[0]["filename"], credential_data
                        return None

                    for row in rows:
                        model_cooldowns = json.loads(row["model_cooldowns"] or "{}")
                        cd = model_cooldowns.get(model_name)
                        if cd is None or current_time >= cd:
                            credential_data = json.loads(row["credential_data"])
                            credential_data["enable_credit"] = bool(row["enable_credit"])
                            return row["filename"], credential_data

                    return None
        except Exception as e:
            log.error(f"Error getting next available credential (mode={mode}, model_name={model_name}): {e}")
            return None

    async def get_available_credentials_list(self) -> List[str]:
        """Return filenames of enabled rows in ``credentials`` by rotation order.

        Only the geminicli table is queried. Returns ``[]`` on error.
        """
        self._ensure_initialized()

        try:
            async with self._pool.acquire() as conn:
                rows = await conn.fetch("""
                    SELECT filename FROM credentials
                    WHERE disabled = 0
                    ORDER BY rotation_order ASC
                """)
                return [r["filename"] for r in rows]
        except Exception as e:
            log.error(f"Error getting available credentials list: {e}")
            return []

    # ============ StorageBackend protocol methods ============

    async def store_credential(
        self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli"
    ) -> bool:
        """Insert a credential, or update its data if the filename exists.

        A new row gets ``rotation_order = MAX(rotation_order) + 1`` and
        ``last_success`` set to the current time; an existing row only has
        ``credential_data`` and ``updated_at`` changed.

        Args:
            filename: credential file name; only its basename is used.
            credential_data: JSON-serializable credential payload.
            mode: ``"geminicli"`` or ``"antigravity"``.

        Returns:
            ``True`` on success, ``False`` on error (logged).
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)

            async with self._pool.acquire() as conn:
                # Single-statement upsert: a separate SELECT followed by
                # INSERT/UPDATE lets two concurrent stores of the same new
                # filename race into a unique-constraint violation.
                await conn.execute(
                    f"""
                    INSERT INTO {table_name}
                        (filename, credential_data, rotation_order, last_success)
                    VALUES (
                        $1,
                        $2,
                        (SELECT COALESCE(MAX(rotation_order), -1) + 1 FROM {table_name}),
                        $3
                    )
                    ON CONFLICT (filename) DO UPDATE SET
                        credential_data = EXCLUDED.credential_data,
                        updated_at = EXTRACT(EPOCH FROM NOW())
                    """,
                    filename,
                    json.dumps(credential_data),
                    time.time(),
                )

            log.debug(f"Stored credential: {filename} (mode={mode})")
            return True
        except Exception as e:
            log.error(f"Error storing credential {filename}: {e}")
            return False

    async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]:
        """Return the decoded credential data, or ``None`` if missing or on error."""
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                row = await conn.fetchrow(
                    f"SELECT credential_data FROM {table_name} WHERE filename = $1",
                    filename,
                )
                if row:
                    return json.loads(row["credential_data"])
                return None
        except Exception as e:
            log.error(f"Error getting credential {filename}: {e}")
            return None

    async def list_credentials(self, mode: str = "geminicli") -> List[str]:
        """List all credential filenames (disabled ones included) by rotation order.

        Returns ``[]`` on error.
        """
        self._ensure_initialized()

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(
                    f"SELECT filename FROM {table_name} ORDER BY rotation_order"
                )
                return [r["filename"] for r in rows]
        except Exception as e:
            log.error(f"Error listing credentials: {e}")
            return []

    async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool:
        """Delete a credential row.

        Returns:
            ``True`` if a row was deleted, ``False`` if none matched or an
            error occurred.
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                result = await conn.execute(
                    f"DELETE FROM {table_name} WHERE filename = $1",
                    filename,
                )
                # asyncpg returns the command tag, e.g. "DELETE N".
                deleted_count = int(result.split()[-1])

            if deleted_count > 0:
                log.debug(f"Deleted credential: {filename} (mode={mode})")
                return True
            else:
                log.warning(f"No credential found to delete: {filename} (mode={mode})")
                return False
        except Exception as e:
            log.error(f"Error deleting credential {filename}: {e}")
            return False

    async def update_credential_state(
        self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli"
    ) -> bool:
        """Update state columns of one credential.

        Keys outside ``STATE_FIELDS`` are ignored, as are keys whose column
        does not exist in the selected table (``enable_credit`` outside
        antigravity mode, ``preview`` outside geminicli mode).
        ``error_codes``, ``error_messages`` and ``model_cooldowns`` are
        JSON-encoded before being written.

        Returns:
            ``True`` if a row was updated or there was nothing to write;
            ``False`` if no row matched or an error occurred.
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            log.debug(f"[DB] update_credential_state: filename={filename}, updates={state_updates}, mode={mode}")

            set_clauses = []
            values = []
            idx = 1  # next positional parameter number ($1, $2, ...)

            for key, value in state_updates.items():
                if key not in self.STATE_FIELDS:
                    continue
                if key == "enable_credit" and mode != "antigravity":
                    continue
                # antigravity_credentials has no ``preview`` column; writing
                # it would make the whole UPDATE fail.
                if key == "preview" and mode != "geminicli":
                    continue

                set_clauses.append(f"{key} = ${idx}")
                if key in ("error_codes", "error_messages", "model_cooldowns"):
                    values.append(json.dumps(value))
                else:
                    values.append(value)
                idx += 1

            if not set_clauses:
                return True

            set_clauses.append("updated_at = EXTRACT(EPOCH FROM NOW())")
            values.append(filename)

            sql = f"""
                UPDATE {table_name}
                SET {', '.join(set_clauses)}
                WHERE filename = ${idx}
            """

            async with self._pool.acquire() as conn:
                result = await conn.execute(sql, *values)
                # Command tag looks like "UPDATE N".
                updated_count = int(result.split()[-1])

            return updated_count > 0
        except Exception as e:
            log.error(f"[DB] Error updating credential state {filename}: {e}")
            return False

    async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]:
        """Return the state of one credential.

        When the row does not exist a default state is returned. A missing
        ``last_success`` is reported as the current time. The result has
        ``preview`` in geminicli mode and ``enable_credit`` in antigravity
        mode. Returns ``{}`` on error.
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                if mode == "geminicli":
                    row = await conn.fetchrow(f"""
                        SELECT disabled, error_codes, last_success, user_email,
                               model_cooldowns, preview, tier
                        FROM {table_name} WHERE filename = $1
                    """, filename)

                    if row:
                        return {
                            "disabled": bool(row["disabled"]),
                            "error_codes": json.loads(row["error_codes"] or "[]"),
                            "last_success": row["last_success"] or time.time(),
                            "user_email": row["user_email"],
                            "model_cooldowns": json.loads(row["model_cooldowns"] or "{}"),
                            "preview": bool(row["preview"]) if row["preview"] is not None else True,
                            "tier": row["tier"] if row["tier"] is not None else "pro",
                        }

                    return {
                        "disabled": False,
                        "error_codes": [],
                        "last_success": time.time(),
                        "user_email": None,
                        "model_cooldowns": {},
                        "preview": True,
                        "tier": "pro",
                    }
                else:
                    row = await conn.fetchrow(f"""
                        SELECT disabled, error_codes, last_success, user_email,
                               model_cooldowns, tier, enable_credit
                        FROM {table_name} WHERE filename = $1
                    """, filename)

                    if row:
                        return {
                            "disabled": bool(row["disabled"]),
                            "error_codes": json.loads(row["error_codes"] or "[]"),
                            "last_success": row["last_success"] or time.time(),
                            "user_email": row["user_email"],
                            "model_cooldowns": json.loads(row["model_cooldowns"] or "{}"),
                            "tier": row["tier"] if row["tier"] is not None else "pro",
                            "enable_credit": bool(row["enable_credit"]) if row["enable_credit"] is not None else False,
                        }

                    return {
                        "disabled": False,
                        "error_codes": [],
                        "last_success": time.time(),
                        "user_email": None,
                        "model_cooldowns": {},
                        "tier": "pro",
                        "enable_credit": False,
                    }
        except Exception as e:
            log.error(f"Error getting credential state {filename}: {e}")
            return {}

    async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]:
        """Return ``{filename: state}`` for every credential of ``mode``.

        Unlike ``get_credential_state``, expired cooldown entries are
        filtered out of ``model_cooldowns``. Returns ``{}`` on error.
        """
        self._ensure_initialized()

        try:
            table_name = self._get_table_name(mode)
            current_time = time.time()

            async with self._pool.acquire() as conn:
                if mode == "geminicli":
                    rows = await conn.fetch(f"""
                        SELECT filename, disabled, error_codes, last_success,
                               user_email, model_cooldowns, preview, tier
                        FROM {table_name}
                    """)

                    states = {}
                    for row in rows:
                        model_cooldowns = json.loads(row["model_cooldowns"] or "{}")
                        if model_cooldowns:
                            # Keep only cooldowns that are still in the future.
                            model_cooldowns = {
                                k: v for k, v in model_cooldowns.items() if v > current_time
                            }

                        states[row["filename"]] = {
                            "disabled": bool(row["disabled"]),
                            "error_codes": json.loads(row["error_codes"] or "[]"),
                            "last_success": row["last_success"] or current_time,
                            "user_email": row["user_email"],
                            "model_cooldowns": model_cooldowns,
                            "preview": bool(row["preview"]) if row["preview"] is not None else True,
                            "tier": row["tier"] if row["tier"] is not None else "pro",
                        }
                    return states
                else:
                    rows = await conn.fetch(f"""
                        SELECT filename, disabled, error_codes, last_success,
                               user_email, model_cooldowns, tier, enable_credit
                        FROM {table_name}
                    """)

                    states = {}
                    for row in rows:
                        model_cooldowns = json.loads(row["model_cooldowns"] or "{}")
                        if model_cooldowns:
                            model_cooldowns = {
                                k: v for k, v in model_cooldowns.items() if v > current_time
                            }

                        states[row["filename"]] = {
                            "disabled": bool(row["disabled"]),
                            "error_codes": json.loads(row["error_codes"] or "[]"),
                            "last_success": row["last_success"] or current_time,
                            "user_email": row["user_email"],
                            "model_cooldowns": model_cooldowns,
                            "tier": row["tier"] if row["tier"] is not None else "pro",
                            "enable_credit": bool(row["enable_credit"]) if row["enable_credit"] is not None else False,
                        }
                    return states
        except Exception as e:
            log.error(f"Error getting all credential states: {e}")
            return {}

    async def get_credentials_summary(
        self,
        offset: int = 0,
        limit: Optional[int] = None,
        status_filter: str = "all",
        mode: str = "geminicli",
        error_code_filter: Optional[str] = None,
        cooldown_filter: Optional[str] = None,
        preview_filter: Optional[str] = None,
        tier_filter: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return a filtered, paginated summary of credentials.

        Args:
            offset: index of the first item to return after filtering.
            limit: maximum number of items; ``None`` returns everything
                from ``offset`` on.
            status_filter: ``"enabled"``, ``"disabled"`` or anything else
                for no status filtering.
            mode: ``"geminicli"`` or ``"antigravity"``.
            error_code_filter: keep only credentials having this error code
                (compared as string and, when numeric, as int); ``None`` or
                ``"all"`` disables the filter.
            cooldown_filter: ``"in_cooldown"`` / ``"no_cooldown"`` keep
                credentials with / without active cooldowns.
            preview_filter: geminicli only; ``"preview"`` / ``"no_preview"``.
            tier_filter: one of ``"free"``, ``"pro"``, ``"ultra"``; other
                values are ignored.

        Returns:
            Dict with ``items``, ``total`` (count after filtering, before
            pagination), ``offset``, ``limit`` and ``stats`` (unfiltered
            total/normal/disabled counts). On error an empty result with
            zeroed stats is returned.
        """
        self._ensure_initialized()

        try:
            table_name = self._get_table_name(mode)
            current_time = time.time()

            async with self._pool.acquire() as conn:
                # Global statistics (independent of all filters).
                stats_rows = await conn.fetch(
                    f"SELECT disabled, COUNT(*) AS cnt FROM {table_name} GROUP BY disabled"
                )
                global_stats = {"total": 0, "normal": 0, "disabled": 0}
                for r in stats_rows:
                    global_stats["total"] += r["cnt"]
                    # Accumulate: several groups (e.g. NULL and 0) can fall
                    # into the same bucket and must not overwrite each other.
                    if r["disabled"]:
                        global_stats["disabled"] += r["cnt"]
                    else:
                        global_stats["normal"] += r["cnt"]

                # WHERE clause
                where_clauses = []
                if status_filter == "enabled":
                    where_clauses.append("disabled = 0")
                elif status_filter == "disabled":
                    where_clauses.append("disabled = 1")
                where_clause = ("WHERE " + " AND ".join(where_clauses)) if where_clauses else ""

                # Query
                if mode == "geminicli":
                    all_rows = await conn.fetch(f"""
                        SELECT filename, disabled, error_codes, last_success,
                               user_email, rotation_order, model_cooldowns, preview, tier
                        FROM {table_name}
                        {where_clause}
                        ORDER BY rotation_order
                    """)
                else:
                    all_rows = await conn.fetch(f"""
                        SELECT filename, disabled, error_codes, last_success,
                               user_email, rotation_order, model_cooldowns, tier, enable_credit
                        FROM {table_name}
                        {where_clause}
                        ORDER BY rotation_order
                    """)

            # Error-code filter: compare both as string and as int.
            filter_value = None
            filter_int = None
            if error_code_filter and str(error_code_filter).strip().lower() != "all":
                filter_value = str(error_code_filter).strip()
                try:
                    filter_int = int(filter_value)
                except ValueError:
                    filter_int = None

            all_summaries = []
            for row in all_rows:
                error_codes_json = row["error_codes"] or "[]"
                model_cooldowns = json.loads(row["model_cooldowns"] or "{}")
                active_cooldowns = {k: v for k, v in model_cooldowns.items() if v > current_time}
                error_codes = json.loads(error_codes_json)

                if filter_value:
                    match = False
                    for code in error_codes:
                        if code == filter_value or code == filter_int:
                            match = True
                            break
                        # Stored codes may be numeric strings such as "429".
                        if isinstance(code, str) and filter_int is not None:
                            try:
                                if int(code) == filter_int:
                                    match = True
                                    break
                            except ValueError:
                                pass
                    if not match:
                        continue

                summary = {
                    "filename": row["filename"],
                    "disabled": bool(row["disabled"]),
                    "error_codes": error_codes,
                    "last_success": row["last_success"] or current_time,
                    "user_email": row["user_email"],
                    "rotation_order": row["rotation_order"],
                    "model_cooldowns": active_cooldowns,
                    "tier": row["tier"] if row["tier"] is not None else "pro",
                }

                if mode == "geminicli":
                    summary["preview"] = bool(row["preview"]) if row["preview"] is not None else True
                    if preview_filter:
                        preview_value = summary.get("preview", True)
                        if preview_filter == "preview" and not preview_value:
                            continue
                        elif preview_filter == "no_preview" and preview_value:
                            continue
                else:
                    summary["enable_credit"] = bool(row["enable_credit"]) if row["enable_credit"] is not None else False

                if tier_filter and tier_filter in ("free", "pro", "ultra"):
                    if summary["tier"] != tier_filter:
                        continue

                if cooldown_filter == "in_cooldown":
                    if active_cooldowns:
                        all_summaries.append(summary)
                elif cooldown_filter == "no_cooldown":
                    if not active_cooldowns:
                        all_summaries.append(summary)
                else:
                    all_summaries.append(summary)

            # Pagination happens in Python because several filters depend
            # on decoded JSON columns.
            total_count = len(all_summaries)
            if limit is not None:
                summaries = all_summaries[offset:offset + limit]
            else:
                summaries = all_summaries[offset:]

            return {
                "items": summaries,
                "total": total_count,
                "offset": offset,
                "limit": limit,
                "stats": global_stats,
            }
        except Exception as e:
            log.error(f"Error getting credentials summary: {e}")
            return {
                "items": [],
                "total": 0,
                "offset": offset,
                "limit": limit,
                "stats": {"total": 0, "normal": 0, "disabled": 0},
            }

    async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]:
        """Group credentials by ``user_email`` and report duplicates.

        Within each email group files are ordered by filename; the first is
        reported as ``kept_file`` and the rest as ``duplicate_files``.
        Credentials without an email are listed separately. On error a
        result with empty collections and zero counts is returned.
        """
        self._ensure_initialized()

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(
                    f"SELECT filename, user_email FROM {table_name} ORDER BY filename"
                )

            email_to_files: Dict[str, List[str]] = {}
            no_email_files: List[str] = []

            for row in rows:
                if row["user_email"]:
                    email_to_files.setdefault(row["user_email"], []).append(row["filename"])
                else:
                    no_email_files.append(row["filename"])

            duplicate_groups = []
            total_duplicate_count = 0
            for email, files in email_to_files.items():
                if len(files) > 1:
                    duplicate_groups.append({
                        "email": email,
                        "kept_file": files[0],
                        "duplicate_files": files[1:],
                        "duplicate_count": len(files) - 1,
                    })
                    total_duplicate_count += len(files) - 1

            return {
                "email_groups": email_to_files,
                "duplicate_groups": duplicate_groups,
                "duplicate_count": total_duplicate_count,
                "no_email_files": no_email_files,
                "no_email_count": len(no_email_files),
                "unique_email_count": len(email_to_files),
                "total_count": len(rows),
            }
        except Exception as e:
            log.error(f"Error getting duplicate credentials by email: {e}")
            return {
                "email_groups": {},
                "duplicate_groups": [],
                "duplicate_count": 0,
                "no_email_files": [],
                "no_email_count": 0,
                "unique_email_count": 0,
                "total_count": 0,
            }

    # ============ Config management (in-memory cache) ============

    async def set_config(self, key: str, value: Any) -> bool:
        """Upsert a config value in the database and update the cache.

        ``value`` is stored JSON-encoded. Returns ``False`` on error, in
        which case the cache is left untouched.
        """
        self._ensure_initialized()

        try:
            async with self._pool.acquire() as conn:
                await conn.execute("""
                    INSERT INTO config (key, value, updated_at)
                    VALUES ($1, $2, EXTRACT(EPOCH FROM NOW()))
                    ON CONFLICT (key) DO UPDATE
                    SET value = EXCLUDED.value, updated_at = EXCLUDED.updated_at
                """, key, json.dumps(value))

            self._config_cache[key] = value
            return True
        except Exception as e:
            log.error(f"Error setting config {key}: {e}")
            return False

    async def reload_config_cache(self) -> None:
        """Discard the loaded flag and reload the config cache from the database."""
        self._ensure_initialized()
        self._config_loaded = False
        await self._load_config_cache()
        log.info("Config cache reloaded from database")

    async def get_config(self, key: str, default: Any = None) -> Any:
        """Return a config value from the in-memory cache, or ``default``."""
        self._ensure_initialized()
        return self._config_cache.get(key, default)

    async def get_all_config(self) -> Dict[str, Any]:
        """Return a shallow copy of the in-memory config cache."""
        self._ensure_initialized()
        return self._config_cache.copy()

    async def delete_config(self, key: str) -> bool:
        """Delete a config key from the database and the cache.

        Returns ``True`` even if the key did not exist; ``False`` on error.
        """
        self._ensure_initialized()

        try:
            async with self._pool.acquire() as conn:
                await conn.execute("DELETE FROM config WHERE key = $1", key)
            self._config_cache.pop(key, None)
            return True
        except Exception as e:
            log.error(f"Error deleting config {key}: {e}")
            return False

    async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]:
        """Return the stored error codes and messages of one credential.

        A missing row yields empty lists. On error the same shape is
        returned with an extra ``"error"`` key holding the exception text.
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                row = await conn.fetchrow(
                    f"SELECT error_codes, error_messages FROM {table_name} WHERE filename = $1",
                    filename,
                )

            if row:
                return {
                    "filename": filename,
                    "error_codes": json.loads(row["error_codes"] or "[]"),
                    "error_messages": json.loads(row["error_messages"] or "[]"),
                }
            return {"filename": filename, "error_codes": [], "error_messages": []}
        except Exception as e:
            log.error(f"Error getting credential errors {filename}: {e}")
            return {"filename": filename, "error_codes": [], "error_messages": [], "error": str(e)}

    # ============ Per-model cooldown management ============

    async def set_model_cooldown(
        self,
        filename: str,
        model_name: str,
        cooldown_until: Optional[float],
        mode: str = "geminicli"
    ) -> bool:
        """Set or clear the cooldown of one model on one credential.

        Args:
            filename: credential file name; only its basename is used.
            model_name: key inside the ``model_cooldowns`` JSON object.
            cooldown_until: value stored for the model (compared against
                ``time.time()`` by the readers); ``None`` removes the entry.
            mode: ``"geminicli"`` or ``"antigravity"``.

        Returns:
            ``True`` on success; ``False`` if the credential does not exist
            or an error occurred.
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                # Read-modify-write of a JSON column: lock the row for the
                # duration so concurrent updates for other models on the
                # same credential are not lost.
                async with conn.transaction():
                    row = await conn.fetchrow(
                        f"SELECT model_cooldowns FROM {table_name} WHERE filename = $1 FOR UPDATE",
                        filename,
                    )
                    if not row:
                        log.warning(f"Credential {filename} not found")
                        return False

                    model_cooldowns = json.loads(row["model_cooldowns"] or "{}")
                    if cooldown_until is None:
                        model_cooldowns.pop(model_name, None)
                    else:
                        model_cooldowns[model_name] = cooldown_until

                    await conn.execute(
                        f"""
                        UPDATE {table_name}
                        SET model_cooldowns = $1, updated_at = EXTRACT(EPOCH FROM NOW())
                        WHERE filename = $2
                        """,
                        json.dumps(model_cooldowns),
                        filename,
                    )

            log.debug(f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}")
            return True
        except Exception as e:
            log.error(f"Error setting model cooldown for {filename}: {e}")
            return False

    async def clear_all_model_cooldowns(
        self,
        filename: str,
        mode: str = "geminicli"
    ) -> bool:
        """Reset ``model_cooldowns`` of one credential to an empty object.

        Returns ``False`` if the credential does not exist or on error.
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                # '{{}}' is the f-string escape for the JSON literal '{}'.
                result = await conn.execute(
                    f"""
                    UPDATE {table_name}
                    SET model_cooldowns = '{{}}', updated_at = EXTRACT(EPOCH FROM NOW())
                    WHERE filename = $1
                    """,
                    filename,
                )
                updated_count = int(result.split()[-1])

            if updated_count == 0:
                log.warning(f"Credential {filename} not found")
                return False

            log.debug(f"Cleared all model cooldowns: {filename} (mode={mode})")
            return True
        except Exception as e:
            log.error(f"Error clearing all model cooldowns for {filename}: {e}")
            return False

    async def record_success(
        self,
        filename: str,
        model_name: Optional[str] = None,
        mode: str = "geminicli"
    ) -> None:
        """Conditionally persist the effects of a successful call.

        Writes happen only when there is something to clean up:

        * if the credential has recorded error codes, they are cleared and
          ``last_success`` is set to now;
        * if ``model_name`` is given and has a cooldown entry, that entry
          is removed.

        Errors are logged and swallowed.
        """
        self._ensure_initialized()
        filename = os.path.basename(filename)

        try:
            table_name = self._get_table_name(mode)
            async with self._pool.acquire() as conn:
                # NOTE(review): error_messages is reset to '{}' here while the
                # column default and the read fallback use '[]' — confirm
                # which shape callers expect before changing either.
                await conn.execute(f"""
                    UPDATE {table_name}
                    SET last_success = EXTRACT(EPOCH FROM NOW()),
                        error_codes = '[]',
                        error_messages = '{{}}',
                        updated_at = EXTRACT(EPOCH FROM NOW())
                    WHERE filename = $1
                      AND (error_codes IS NOT NULL AND error_codes != '[]' AND error_codes != '')
                """, filename)

                if model_name:
                    # Lock the row while editing the JSON so a concurrent
                    # cooldown update on the same credential is not lost.
                    async with conn.transaction():
                        row = await conn.fetchrow(
                            f"SELECT model_cooldowns FROM {table_name} WHERE filename = $1 FOR UPDATE",
                            filename,
                        )
                        if row:
                            cooldowns = json.loads(row["model_cooldowns"] or "{}")
                            if model_name in cooldowns:
                                cooldowns.pop(model_name)
                                await conn.execute(
                                    f"""
                                    UPDATE {table_name}
                                    SET model_cooldowns = $1, updated_at = EXTRACT(EPOCH FROM NOW())
                                    WHERE filename = $2
                                    """,
                                    json.dumps(cooldowns),
                                    filename,
                                )
        except Exception as e:
            log.error(f"Error recording success for {filename}: {e}")