"""Database abstraction layer.

Provides a file-system backend and remote SQL backends (PostgreSQL, MySQL,
SQLite), plus a manager that keeps the account lists of a local and a remote
backend synchronized in both directions.
"""

import asyncio
import hashlib
import inspect
import json
import logging
import os
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


def _canonical_hash(data: Any) -> str:
    """Return the MD5 hex digest of *data* serialized as key-sorted JSON.

    Used for change detection, so logically equal structures hash the same
    regardless of key order.
    """
    return hashlib.md5(json.dumps(data, sort_keys=True).encode()).hexdigest()


def _decode_json(raw: Any) -> Any:
    """Decode a JSON column value as returned by a database driver.

    asyncpg returns ``json``/``jsonb`` columns as ``str`` unless a type codec
    is registered; other drivers may hand back ``str`` or ``bytes``.  Values
    that are already decoded (e.g. a driver with a custom codec) are returned
    unchanged.
    """
    if isinstance(raw, (bytes, bytearray)):
        raw = raw.decode("utf-8")
    if isinstance(raw, str):
        return json.loads(raw)
    return raw


class DatabaseInterface(ABC):
    """Abstract storage backend for account, full and admin configuration."""

    @abstractmethod
    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Persist the account list; return True on success."""

    @abstractmethod
    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return the stored account list (empty list if none)."""

    @abstractmethod
    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Persist the complete configuration; return True on success."""

    @abstractmethod
    async def load_config(self) -> Dict[str, Any]:
        """Return the complete configuration (empty dict if none)."""

    @abstractmethod
    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Persist the admin configuration; return True on success."""

    @abstractmethod
    async def load_admin_config(self) -> Dict[str, Any]:
        """Return the admin configuration (empty dict if none)."""

    @abstractmethod
    async def initialize(self) -> bool:
        """Prepare the backend for use; return True on success."""

    @abstractmethod
    async def get_config_hash(self) -> str:
        """Return a digest of the stored configuration for change detection."""

    async def close(self) -> None:
        """Release resources held by the backend (no-op by default)."""
        return None


class FileSystemDatabase(DatabaseInterface):
    """File-system backend (the original storage logic).

    The full configuration lives in ``config.json`` and the admin
    configuration in ``admin.json``, both inside *data_dir*.
    """

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.config_file = data_dir / "config.json"
        self.admin_config_file = data_dir / "admin.json"

    def _ensure_dir(self):
        """Create the data directory (and its parents) if missing."""
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def _write_json_atomic(self, path: Path, data: Any) -> None:
        """Write *data* as pretty-printed JSON to *path* atomically.

        The JSON goes to a temporary file in the same directory and is then
        moved over *path* with ``os.replace``. Concurrent readers (e.g. the
        sync loop) therefore never see a half-written file. A serialization
        error also leaves the previous file intact.

        Raises:
            OSError / TypeError / ValueError: propagated to the caller after
            the temporary file has been removed.
        """
        self._ensure_dir()
        fd, tmp_name = tempfile.mkstemp(
            dir=str(self.data_dir), prefix=f".{path.name}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
                f.flush()
                os.fsync(f.fileno())
            try:
                # Keep the permissions of the file being replaced.
                os.chmod(tmp_name, path.stat().st_mode & 0o7777)
            except OSError:
                # Target does not exist yet: keep mkstemp's owner-only mode.
                pass
            os.replace(tmp_name, path)
        except BaseException:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    @staticmethod
    def _read_json_object(path: Path, error_prefix: str) -> Dict[str, Any]:
        """Load a JSON object from *path*; return ``{}`` if missing or invalid.

        Errors are logged as ``"<error_prefix>: <detail>"``.
        """
        try:
            if path.exists():
                with open(path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict):
                    return data
                logger.error(f"{error_prefix}: {path} 顶层不是 JSON 对象")
        except Exception as e:
            logger.error(f"{error_prefix}: {e}")
        return {}

    async def initialize(self) -> bool:
        """Ensure the data directory exists; return False on failure."""
        try:
            self._ensure_dir()
            return True
        except Exception as e:
            logger.error(f"文件系统初始化失败: {e}")
            return False

    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Replace the ``accounts`` key of config.json, keeping other keys."""
        try:
            self._ensure_dir()
            config = await self.load_config()
            config["accounts"] = accounts
            return await self.save_config(config)
        except Exception as e:
            logger.error(f"保存账号配置失败: {e}")
            return False

    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return the ``accounts`` list from config.json (``[]`` if absent)."""
        config = await self.load_config()
        return config.get("accounts", [])

    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Atomically overwrite config.json with *config*."""
        try:
            self._write_json_atomic(self.config_file, config)
            return True
        except Exception as e:
            logger.error(f"保存配置失败: {e}")
            return False

    async def load_config(self) -> Dict[str, Any]:
        """Return the parsed config.json, or ``{}`` if missing/unreadable."""
        return self._read_json_object(self.config_file, "加载配置失败")

    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Atomically overwrite admin.json with *admin_config*."""
        try:
            self._write_json_atomic(self.admin_config_file, admin_config)
            return True
        except Exception as e:
            logger.error(f"保存管理员配置失败: {e}")
            return False

    async def load_admin_config(self) -> Dict[str, Any]:
        """Return the parsed admin.json, or ``{}`` if missing/unreadable."""
        return self._read_json_object(self.admin_config_file, "加载管理员配置失败")

    async def get_config_hash(self) -> str:
        """Return the MD5 of config.json's raw bytes ("" if absent/unreadable).

        Hashing raw bytes means any edit to the file counts as a change.
        """
        try:
            if self.config_file.exists():
                return hashlib.md5(self.config_file.read_bytes()).hexdigest()
        except Exception as e:
            logger.error(f"获取配置哈希失败: {e}")
        return ""


class SQLDatabase(DatabaseInterface):
    """SQL backend storing JSON documents in ``kiro_config``/``kiro_admin``.

    Supported URL schemes: ``postgresql://``/``postgres://`` (asyncpg),
    ``mysql://``/``mysql+...://`` (aiomysql) and ``sqlite://`` (aiosqlite).
    """

    def __init__(self, database_url: str):
        self.database_url = database_url
        self.pool = None
        self.db_path: Optional[str] = None  # set by _init_sqlite
        self._db_type = self._detect_db_type(database_url)

    def _detect_db_type(self, url: str) -> str:
        """Map the URL scheme to "postgresql", "mysql", "sqlite" or "unknown"."""
        if url.startswith(("postgresql://", "postgres://")):
            return "postgresql"
        if url.startswith(("mysql://", "mysql+")):
            return "mysql"
        if url.startswith("sqlite://"):
            return "sqlite"
        return "unknown"

    async def initialize(self) -> bool:
        """Open the connection pool and create tables; return False on failure.

        On failure any partially opened pool is closed so it does not leak.
        """
        initializers = {
            "postgresql": self._init_postgresql,
            "mysql": self._init_mysql,
            "sqlite": self._init_sqlite,
        }
        init = initializers.get(self._db_type)
        if init is None:
            logger.error(f"不支持的数据库类型: {self._db_type}")
            return False
        try:
            await init()
            await self._create_tables()
            return True
        except Exception as e:
            logger.error(f"SQL数据库初始化失败: {e}")
            await self.close()
            return False

    async def close(self) -> None:
        """Close the connection pool, if one is open. Errors are logged."""
        pool, self.pool = self.pool, None
        if pool is None:
            return
        try:
            if self._db_type == "mysql":
                # aiomysql: close() is synchronous, wait_closed() awaits shutdown.
                pool.close()
                await pool.wait_closed()
            else:
                await pool.close()
        except Exception as e:
            logger.error(f"关闭数据库连接池失败: {e}")

    async def _init_postgresql(self):
        """Create an asyncpg pool from the URL."""
        try:
            import asyncpg
        except ImportError as err:
            raise ImportError("请安装 asyncpg: pip install asyncpg") from err
        self.pool = await asyncpg.create_pool(self.database_url)

    async def _init_mysql(self):
        """Create an aiomysql pool from the URL's components."""
        try:
            import aiomysql
        except ImportError as err:
            raise ImportError("请安装 aiomysql: pip install aiomysql") from err
        from urllib.parse import unquote, urlparse

        parsed = urlparse(self.database_url)
        # urlparse leaves percent-escapes in place; credentials such as
        # "p%40ss" must be decoded before being sent to the server.
        user = unquote(parsed.username) if parsed.username else parsed.username
        password = unquote(parsed.password) if parsed.password else parsed.password
        self.pool = await aiomysql.create_pool(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=user,
            password=password,
            db=unquote(parsed.path.lstrip("/")),
            charset="utf8mb4",
            # Each statement is its own transaction, so reads are not left
            # inside an open transaction on pooled connections.
            autocommit=True,
        )

    async def _init_sqlite(self):
        """Record the SQLite file path after checking aiosqlite is installed."""
        try:
            import aiosqlite  # noqa: F401  (imported only to verify availability)
        except ImportError as err:
            raise ImportError("请安装 aiosqlite: pip install aiosqlite") from err
        # Strip only the leading scheme: "sqlite:///data/kiro.db" -> "/data/kiro.db".
        self.db_path = self.database_url[len("sqlite://"):]

    async def _create_tables(self):
        """Create the backend-specific tables if they do not exist."""
        if self._db_type == "postgresql":
            await self._create_tables_postgresql()
        elif self._db_type == "mysql":
            await self._create_tables_mysql()
        elif self._db_type == "sqlite":
            await self._create_tables_sqlite()

    async def _create_tables_postgresql(self):
        """Create the PostgreSQL tables (JSONB value column)."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS kiro_config (
                    id SERIAL PRIMARY KEY,
                    key VARCHAR(255) UNIQUE NOT NULL,
                    value JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS kiro_admin (
                    id SERIAL PRIMARY KEY,
                    key VARCHAR(255) UNIQUE NOT NULL,
                    value JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

    async def _create_tables_mysql(self):
        """Create the MySQL tables (JSON value column, InnoDB, utf8mb4)."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS kiro_config (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        `key` VARCHAR(255) UNIQUE NOT NULL,
                        value JSON NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """)
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS kiro_admin (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        `key` VARCHAR(255) UNIQUE NOT NULL,
                        value JSON NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """)

    async def _create_tables_sqlite(self):
        """Create the SQLite tables (JSON stored as TEXT)."""
        import aiosqlite
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                CREATE TABLE IF NOT EXISTS kiro_config (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE NOT NULL,
                    value TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS kiro_admin (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE NOT NULL,
                    value TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            await db.commit()

    async def _get_config_value(self, table: str, key: str) -> Optional[Dict[str, Any]]:
        """Return the decoded JSON stored under *key* in *table*, or None.

        *table* is interpolated into the SQL and must be one of this class's
        fixed table names; *key* is passed as a bound parameter.
        """
        try:
            if self._db_type == "postgresql":
                async with self.pool.acquire() as conn:
                    row = await conn.fetchrow(f"SELECT value FROM {table} WHERE key = $1", key)
                return _decode_json(row["value"]) if row else None
            if self._db_type == "mysql":
                async with self.pool.acquire() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(f"SELECT value FROM {table} WHERE `key` = %s", (key,))
                        row = await cursor.fetchone()
                return _decode_json(row[0]) if row else None
            if self._db_type == "sqlite":
                import aiosqlite
                async with aiosqlite.connect(self.db_path) as db:
                    cursor = await db.execute(f"SELECT value FROM {table} WHERE key = ?", (key,))
                    row = await cursor.fetchone()
                return _decode_json(row[0]) if row else None
        except Exception as e:
            logger.error(f"获取配置失败 {table}.{key}: {e}")
        return None

    async def _set_config_value(self, table: str, key: str, value: Dict[str, Any]) -> bool:
        """Upsert *value* (serialized as JSON) under *key* in *table*.

        Returns False on error or for an unsupported backend type.
        """
        try:
            payload = json.dumps(value)
            if self._db_type == "postgresql":
                async with self.pool.acquire() as conn:
                    await conn.execute(f"""
                        INSERT INTO {table} (key, value) VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = CURRENT_TIMESTAMP
                    """, key, payload)
            elif self._db_type == "mysql":
                async with self.pool.acquire() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(f"""
                            INSERT INTO {table} (`key`, value) VALUES (%s, %s)
                            ON DUPLICATE KEY UPDATE value = %s, updated_at = CURRENT_TIMESTAMP
                        """, (key, payload, payload))
                    await conn.commit()
            elif self._db_type == "sqlite":
                import aiosqlite
                async with aiosqlite.connect(self.db_path) as db:
                    # INSERT OR REPLACE deletes and re-inserts a conflicting
                    # row, so created_at is reset on every save.
                    await db.execute(f"""
                        INSERT OR REPLACE INTO {table} (key, value, updated_at)
                        VALUES (?, ?, CURRENT_TIMESTAMP)
                    """, (key, payload))
                    await db.commit()
            else:
                logger.error(f"不支持的数据库类型: {self._db_type}")
                return False
            return True
        except Exception as e:
            logger.error(f"设置配置失败 {table}.{key}: {e}")
            return False

    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Replace the ``accounts`` key of the main config, keeping other keys."""
        config = await self.load_config()
        config["accounts"] = accounts
        return await self.save_config(config)

    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return the ``accounts`` list of the main config (``[]`` if absent)."""
        config = await self.load_config()
        return config.get("accounts", [])

    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Store *config* as row ``main`` of ``kiro_config``."""
        return await self._set_config_value("kiro_config", "main", config)

    async def load_config(self) -> Dict[str, Any]:
        """Return row ``main`` of ``kiro_config`` ({} if missing or not an object)."""
        config = await self._get_config_value("kiro_config", "main")
        return config if isinstance(config, dict) else {}

    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Store *admin_config* as row ``main`` of ``kiro_admin``."""
        return await self._set_config_value("kiro_admin", "main", admin_config)

    async def load_admin_config(self) -> Dict[str, Any]:
        """Return row ``main`` of ``kiro_admin`` ({} if missing or not an object)."""
        config = await self._get_config_value("kiro_admin", "main")
        return config if isinstance(config, dict) else {}

    async def get_config_hash(self) -> str:
        """Return the MD5 of the main config as key-sorted JSON ("" on error)."""
        try:
            return _canonical_hash(await self.load_config())
        except Exception as e:
            logger.error(f"获取配置哈希失败: {e}")
            return ""


def _get_account_key(account: Dict[str, Any]) -> str:
    """Return the identity key used to deduplicate *account*.

    Preference order: email, then a SHA-256 digest of the full token, then id,
    then a digest of the whole record. The full token is hashed because
    tokens with a shared prefix (e.g. JWT headers) are distinct accounts. The
    digest also keeps raw token material out of log messages.
    """
    if account.get("email"):
        return f"email:{account['email']}"
    if account.get("token"):
        digest = hashlib.sha256(str(account["token"]).encode("utf-8")).hexdigest()
        return f"token:{digest}"
    if account.get("id"):
        return f"id:{account['id']}"
    # Stable across processes, unlike the built-in hash() of a str.
    return f"hash:{_canonical_hash(account)}"


def _merge_accounts(local_accounts: List[Dict[str, Any]],
                    remote_accounts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Merge local into remote accounts, deduplicating with remote taking priority.

    Remote accounts come first (deduplicated among themselves), followed by
    local accounts whose key is not already present.
    """
    seen_keys = set()
    merged = []
    for account in remote_accounts:
        key = _get_account_key(account)
        if key not in seen_keys:
            seen_keys.add(key)
            merged.append(account)
    for account in local_accounts:
        key = _get_account_key(account)
        if key not in seen_keys:
            seen_keys.add(key)
            merged.append(account)
            logger.info(f"从本地合并账号: {key}")
    return merged


def create_database() -> DatabaseInterface:
    """Build a backend from the DATABASE_URL environment variable.

    A supported DATABASE_URL yields an (uninitialized) SQLDatabase; a missing
    or unsupported one falls back to FileSystemDatabase at ``config.DATA_DIR``.
    """
    database_url = os.getenv("DATABASE_URL")
    if database_url:
        candidate = SQLDatabase(database_url)
        if candidate._db_type != "unknown":
            logger.info(f"使用远程数据库: {candidate._db_type}")
            return candidate
        logger.warning("DATABASE_URL 格式不支持,将回退到文件系统数据库")
    from ..config import DATA_DIR
    logger.info(f"使用文件系统数据库: {DATA_DIR}")
    return FileSystemDatabase(DATA_DIR)


# Process-wide database singleton and whether the one-time local->remote
# account merge has already run.
_db_instance: Optional[DatabaseInterface] = None
_merge_completed: bool = False


async def get_database() -> DatabaseInterface:
    """Return the shared database instance, creating it on first use.

    If the configured backend fails to initialize, it is closed and a
    FileSystemDatabase is used instead. When the instance is an SQLDatabase,
    local accounts are merged into it once per process.
    """
    global _db_instance, _merge_completed
    if _db_instance is None:
        candidate = create_database()
        if not await candidate.initialize():
            await candidate.close()
            from ..config import DATA_DIR
            logger.warning(f"数据库初始化失败,将回退到文件系统数据库: {DATA_DIR}")
            candidate = FileSystemDatabase(DATA_DIR)
            await candidate.initialize()
        _db_instance = candidate
    if isinstance(_db_instance, SQLDatabase) and not _merge_completed:
        await _auto_merge_local_to_remote()
        _merge_completed = True
    return _db_instance


async def _auto_merge_local_to_remote():
    """Best-effort merge of local file-system accounts into the remote database.

    Errors are logged, never raised.
    """
    try:
        from ..config import DATA_DIR
        local_db = FileSystemDatabase(DATA_DIR)
        await local_db.initialize()
        local_accounts = await local_db.load_accounts()
        if not local_accounts:
            logger.info("本地无账号,跳过合并")
            return
        remote_accounts = await _db_instance.load_accounts()
        merged = _merge_accounts(local_accounts, remote_accounts)
        # Count distinct remote accounts so duplicate remote rows are not
        # mistaken for (or hide) newly merged local accounts.
        remote_unique = len({_get_account_key(a) for a in remote_accounts})
        new_count = len(merged) - remote_unique
        if new_count > 0:
            if not await _db_instance.save_accounts(merged):
                logger.error("自动合并账号失败: 保存到远程数据库失败")
                return
            logger.info(f"已将 {new_count} 个本地账号合并到远程数据库,共 {len(merged)} 个账号")
        else:
            logger.info("所有本地账号已存在于远程,无需合并")
    except Exception as e:
        logger.error(f"自动合并账号失败: {e}")


class SyncManager:
    """Bidirectional account synchronizer between a local and a remote backend.

    Every ``SYNC_INTERVAL`` seconds (default 30) it compares each side's config
    hash with the one recorded after the last sync. If either changed, it
    merges both account lists and writes the result to both sides.
    Registered callbacks run after each sync.
    """

    def __init__(self, local_db: FileSystemDatabase, remote_db: SQLDatabase):
        self.local_db = local_db
        self.remote_db = remote_db
        self._local_hash: str = ""
        self._remote_hash: str = ""
        self._sync_lock = asyncio.Lock()
        self._running = False
        self._sync_task: Optional[asyncio.Task] = None
        self._sync_interval = self._parse_sync_interval(os.getenv("SYNC_INTERVAL"))
        self._callbacks: List[Callable] = []

    @staticmethod
    def _parse_sync_interval(raw: Optional[str], default: int = 30) -> int:
        """Parse the SYNC_INTERVAL setting into a positive number of seconds.

        Unset or non-integer values fall back to *default*. Values below 1 are
        raised to 1 so the loop cannot spin without sleeping.
        """
        if raw is None:
            return default
        try:
            interval = int(raw)
        except ValueError:
            logger.warning(f"SYNC_INTERVAL 无效: {raw!r},使用默认值 {default} 秒")
            return default
        if interval < 1:
            logger.warning(f"SYNC_INTERVAL 必须至少为 1 秒,当前为 {interval},已调整为 1 秒")
            return 1
        return interval

    def register_callback(self, callback: Callable):
        """Register a sync or async callable to run after each sync."""
        self._callbacks.append(callback)

    async def _notify_callbacks(self):
        """Run every callback, awaiting any awaitable result; errors are logged."""
        for callback in list(self._callbacks):
            try:
                result = callback()
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.error(f"同步回调执行失败: {e}")

    async def _save_accounts_to(self, accounts: List[Dict[str, Any]], targets) -> bool:
        """Save *accounts* to each backend in *targets*, in order.

        Returns True only if every save succeeded; failures are logged.
        """
        all_ok = True
        for db in targets:
            if not await db.save_accounts(accounts):
                logger.warning(f"同步写入失败: {type(db).__name__}")
                all_ok = False
        return all_ok

    async def _record_hashes(self):
        """Remember both sides' current hashes as the new sync baseline."""
        self._local_hash = await self.local_db.get_config_hash()
        self._remote_hash = await self.remote_db.get_config_hash()

    async def start(self):
        """Run an initial merge, then start the periodic sync task (idempotent)."""
        if self._running:
            return
        self._running = True
        await self._initial_sync()
        self._sync_task = asyncio.create_task(self._sync_loop())
        logger.info(f"双向同步已启动,间隔 {self._sync_interval} 秒")

    async def stop(self):
        """Cancel the periodic sync task and wait for it to finish."""
        self._running = False
        if self._sync_task:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass
            self._sync_task = None
        logger.info("双向同步已停止")

    async def _initial_sync(self):
        """Merge both sides' accounts and write the result to both."""
        async with self._sync_lock:
            try:
                local_accounts = await self.local_db.load_accounts()
                remote_accounts = await self.remote_db.load_accounts()
                merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)
                await self._save_accounts_to(merged, (self.local_db, self.remote_db))
                await self._record_hashes()
                logger.info(f"初始同步完成,共 {len(merged)} 个账号")
                await self._notify_callbacks()
            except Exception as e:
                logger.error(f"初始同步失败: {e}")

    async def _sync_loop(self):
        """Sleep for the interval, then check for changes, until stopped."""
        while self._running:
            try:
                await asyncio.sleep(self._sync_interval)
                await self._check_and_sync()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"同步循环异常: {e}")

    async def _check_and_sync(self):
        """Merge and write both sides if either changed since the last sync."""
        async with self._sync_lock:
            try:
                current_local_hash = await self.local_db.get_config_hash()
                current_remote_hash = await self.remote_db.get_config_hash()
                local_changed = current_local_hash != self._local_hash
                remote_changed = current_remote_hash != self._remote_hash
                if not local_changed and not remote_changed:
                    return

                local_accounts = await self.local_db.load_accounts()
                remote_accounts = await self.remote_db.load_accounts()
                # The merge is identical in every case; only the log message
                # and which side is written first differ.
                if local_changed and remote_changed:
                    logger.info("检测到本地和远程都有变更,执行合并")
                    targets = (self.local_db, self.remote_db)
                elif local_changed:
                    logger.info("检测到本地变更,同步到远程")
                    targets = (self.remote_db, self.local_db)
                else:
                    logger.info("检测到远程变更,同步到本地")
                    targets = (self.local_db, self.remote_db)
                merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)
                await self._save_accounts_to(merged, targets)

                await self._record_hashes()
                await self._notify_callbacks()
            except Exception as e:
                logger.error(f"同步检查失败: {e}")

    async def force_sync(self, source: str = "merge"):
        """Synchronize immediately.

        Args:
            source: "local" (local overwrites remote), "remote" (remote
                overwrites local), or anything else to merge both ways.

        Returns:
            True if every save succeeded, False on any failure.
        """
        async with self._sync_lock:
            try:
                if source == "local":
                    accounts = await self.local_db.load_accounts()
                    ok = await self._save_accounts_to(accounts, (self.remote_db,))
                    logger.info(f"强制同步:本地 -> 远程,{len(accounts)} 个账号")
                elif source == "remote":
                    accounts = await self.remote_db.load_accounts()
                    ok = await self._save_accounts_to(accounts, (self.local_db,))
                    logger.info(f"强制同步:远程 -> 本地,{len(accounts)} 个账号")
                else:
                    local_accounts = await self.local_db.load_accounts()
                    remote_accounts = await self.remote_db.load_accounts()
                    merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)
                    ok = await self._save_accounts_to(merged, (self.local_db, self.remote_db))
                    logger.info(f"强制同步:双向合并,{len(merged)} 个账号")
                await self._record_hashes()
                await self._notify_callbacks()
                return ok
            except Exception as e:
                logger.error(f"强制同步失败: {e}")
                return False

    async def get_sync_status(self) -> Dict[str, Any]:
        """Return a status snapshot.

        ``in_sync`` is True when both sides hold the same account list. The
        sides' config hashes are computed differently (raw file bytes vs
        canonical JSON) and the configs may carry different non-account keys,
        so hashes are not compared for this flag.
        """
        try:
            local_accounts = await self.local_db.load_accounts()
            remote_accounts = await self.remote_db.load_accounts()
            current_local_hash = await self.local_db.get_config_hash()
            current_remote_hash = await self.remote_db.get_config_hash()
            return {
                "enabled": True,
                "running": self._running,
                "sync_interval": self._sync_interval,
                "local_accounts": len(local_accounts),
                "remote_accounts": len(remote_accounts),
                "local_changed": current_local_hash != self._local_hash,
                "remote_changed": current_remote_hash != self._remote_hash,
                "in_sync": local_accounts == remote_accounts,
            }
        except Exception as e:
            return {
                "enabled": True,
                "running": self._running,
                "error": str(e),
            }


def _remote_is_newer(remote: Dict[str, Any], current: Dict[str, Any]) -> bool:
    """Return True if *remote*'s ``updated_at`` is strictly newer than *current*'s.

    A missing ``updated_at`` counts as 0. Incomparable values (e.g. None vs
    int, str vs int) keep *current* instead of aborting the whole sync.
    """
    remote_ts = remote.get("updated_at", 0)
    current_ts = current.get("updated_at", 0)
    try:
        return remote_ts > current_ts
    except TypeError:
        logger.warning(f"账号 updated_at 无法比较 ({remote_ts!r} vs {current_ts!r}),保留现有记录")
        return False


def _merge_accounts_bidirectional(
    local_accounts: List[Dict[str, Any]],
    remote_accounts: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Merge local and remote accounts, newest ``updated_at`` winning per key.

    Accounts are keyed with _get_account_key. On a tie the entry already
    present (local first) is kept. Keys starting with "_" are stripped from
    the result. Order follows first appearance (local, then new remote). The
    input lists and their dicts are not modified.
    """
    accounts_map: Dict[str, Dict[str, Any]] = {}
    for account in local_accounts:
        accounts_map[_get_account_key(account)] = account
    for account in remote_accounts:
        key = _get_account_key(account)
        current = accounts_map.get(key)
        if current is None or _remote_is_newer(account, current):
            accounts_map[key] = account
    return [
        {k: v for k, v in account.items() if not k.startswith("_")}
        for account in accounts_map.values()
    ]


# Process-wide sync manager, set by start_sync_manager.
_sync_manager: Optional[SyncManager] = None


async def get_sync_manager() -> Optional[SyncManager]:
    """Return the running sync manager, or None if sync is not active."""
    return _sync_manager


async def start_sync_manager():
    """Start local<->remote sync when DATABASE_URL names a supported database.

    Does nothing if a manager is already running, DATABASE_URL is unset or
    unsupported, or the remote database fails to initialize.
    """
    global _sync_manager
    if _sync_manager is not None:
        logger.info("同步管理器已在运行,跳过重复启动")
        return
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        logger.info("未配置 DATABASE_URL,跳过同步管理器")
        return
    from ..config import DATA_DIR
    local_db = FileSystemDatabase(DATA_DIR)
    await local_db.initialize()
    remote_db = SQLDatabase(database_url)
    if remote_db._db_type == "unknown":
        logger.warning("DATABASE_URL 格式不支持,跳过同步管理器")
        return
    if not await remote_db.initialize():
        await remote_db.close()
        logger.warning("远程数据库初始化失败,跳过同步管理器")
        return
    _sync_manager = SyncManager(local_db, remote_db)
    await _sync_manager.start()


async def stop_sync_manager():
    """Stop the running sync manager and close its remote connection pool."""
    global _sync_manager
    manager, _sync_manager = _sync_manager, None
    if manager is None:
        return
    try:
        await manager.stop()
    finally:
        await manager.remote_db.close()