|
|
"""数据库抽象层 - 支持文件系统和远程SQL数据库,以及双向同步""" |
|
|
import os |
|
|
import json |
|
|
import asyncio |
|
|
import hashlib |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import List, Dict, Any, Optional, Callable |
|
|
from pathlib import Path |
|
|
import logging |
|
|
import time |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class DatabaseInterface(ABC):
    """Abstract storage contract shared by every configuration backend.

    Every operation is a coroutine. A concrete backend has to override all
    eight methods declared here before it can be instantiated.
    """

    @abstractmethod
    async def initialize(self) -> bool:
        """Initialise the backend."""

    @abstractmethod
    async def get_config_hash(self) -> str:
        """Return a hash of the configuration, used for change detection."""

    @abstractmethod
    async def load_config(self) -> Dict[str, Any]:
        """Load the complete configuration."""

    @abstractmethod
    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Save the complete configuration."""

    @abstractmethod
    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Load the account configuration."""

    @abstractmethod
    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Save the account configuration."""

    @abstractmethod
    async def load_admin_config(self) -> Dict[str, Any]:
        """Load the administrator configuration."""

    @abstractmethod
    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Save the administrator configuration."""
|
|
|
|
|
|
|
|
class FileSystemDatabase(DatabaseInterface):
    """File-system backend that keeps the configuration in JSON files.

    Two files live under ``data_dir``: ``config.json`` (the complete
    configuration, including its ``"accounts"`` list) and ``admin.json``
    (the administrator configuration). Failures are logged and reported
    through the return value (``False`` / ``{}`` / ``""``) instead of being
    raised.
    """

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.config_file = data_dir / "config.json"
        self.admin_config_file = data_dir / "admin.json"

    def _ensure_dir(self):
        """Create the data directory (and its parents) when missing."""
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def _write_json(self, path: Path, payload: Dict[str, Any]) -> None:
        """Write ``payload`` to ``path`` as indented UTF-8 JSON.

        The payload is encoded *before* the file is opened: opening with
        ``"w"`` truncates the target, so an encoding error raised half-way
        through ``json.dump`` would leave a truncated, unparsable file.
        """
        self._ensure_dir()
        text = json.dumps(payload, indent=2, ensure_ascii=False)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    def _read_json(self, path: Path) -> Dict[str, Any]:
        """Return the JSON object stored at ``path`` (``{}`` when absent).

        Raises:
            ValueError: if the file holds valid JSON that is not an object;
                callers index the result like a dict, so anything else
                would only fail later with a less helpful error.
        """
        if not path.exists():
            return {}
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object in {path}, got {type(data).__name__}"
            )
        return data

    async def initialize(self) -> bool:
        """Make sure the data directory exists; return ``False`` on failure."""
        try:
            self._ensure_dir()
            return True
        except Exception as e:
            logger.error(f"文件系统初始化失败: {e}")
            return False

    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Replace the ``"accounts"`` entry of the stored configuration."""
        try:
            self._ensure_dir()
            config = await self.load_config()
            config["accounts"] = accounts
            return await self.save_config(config)
        except Exception as e:
            logger.error(f"保存账号配置失败: {e}")
            return False

    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return the ``"accounts"`` list of the configuration (``[]`` if unset)."""
        config = await self.load_config()
        return config.get("accounts", [])

    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Write the complete configuration; return whether it succeeded."""
        try:
            self._write_json(self.config_file, config)
            return True
        except Exception as e:
            logger.error(f"保存配置失败: {e}")
            return False

    async def load_config(self) -> Dict[str, Any]:
        """Read the complete configuration; ``{}`` when missing or unreadable."""
        try:
            return self._read_json(self.config_file)
        except Exception as e:
            logger.error(f"加载配置失败: {e}")
        return {}

    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Write the administrator configuration; return whether it succeeded."""
        try:
            self._write_json(self.admin_config_file, admin_config)
            return True
        except Exception as e:
            logger.error(f"保存管理员配置失败: {e}")
            return False

    async def load_admin_config(self) -> Dict[str, Any]:
        """Read the administrator configuration; ``{}`` when missing or unreadable."""
        try:
            return self._read_json(self.admin_config_file)
        except Exception as e:
            logger.error(f"加载管理员配置失败: {e}")
        return {}

    async def get_config_hash(self) -> str:
        """Return the MD5 hex digest of ``config.json``'s raw bytes.

        Returns ``""`` when the file does not exist or cannot be read. The
        digest is only used for change detection, not for security.
        """
        try:
            if self.config_file.exists():
                content = self.config_file.read_bytes()
                return hashlib.md5(content).hexdigest()
        except Exception as e:
            logger.error(f"获取配置哈希失败: {e}")
        return ""
|
|
|
|
|
|
|
|
class SQLDatabase(DatabaseInterface):
    """SQL-backed implementation of :class:`DatabaseInterface`.

    The driver is picked from the URL scheme and imported lazily:
    PostgreSQL via ``asyncpg``, MySQL via ``aiomysql``, SQLite via
    ``aiosqlite``. Each configuration is stored as a single JSON document in
    the row whose ``key`` is ``"main"`` of the ``kiro_config`` /
    ``kiro_admin`` table.
    """

    def __init__(self, database_url: str):
        self.database_url = database_url
        # Connection pool for PostgreSQL / MySQL; SQLite opens a connection
        # per operation and only needs ``db_path``.
        self.pool = None
        self.db_path: Optional[str] = None
        self._db_type = self._detect_db_type(database_url)

    def _detect_db_type(self, url: str) -> str:
        """Map the URL scheme to ``postgresql``/``mysql``/``sqlite``/``unknown``."""
        if url.startswith(("postgresql://", "postgres://")):
            return "postgresql"
        if url.startswith(("mysql://", "mysql+")):
            return "mysql"
        if url.startswith("sqlite://"):
            return "sqlite"
        return "unknown"

    @staticmethod
    def _decode_json(raw: Any) -> Any:
        """Turn a value read from a JSON column into a Python object.

        Drivers hand JSON columns back as text unless a codec has been
        registered (asyncpg does so for ``json``/``jsonb``), so textual
        values are parsed here; anything already decoded passes through.
        """
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    async def initialize(self) -> bool:
        """Open the connection (pool) and create the tables.

        Returns ``False`` (after logging) for an unsupported URL scheme or
        when any step raises.
        """
        try:
            if self._db_type == "postgresql":
                await self._init_postgresql()
            elif self._db_type == "mysql":
                await self._init_mysql()
            elif self._db_type == "sqlite":
                await self._init_sqlite()
            else:
                logger.error(f"不支持的数据库类型: {self._db_type}")
                return False

            await self._create_tables()
            return True
        except Exception as e:
            logger.error(f"SQL数据库初始化失败: {e}")
            return False

    async def _init_postgresql(self):
        """Create the asyncpg connection pool."""
        try:
            import asyncpg
        except ImportError as exc:
            raise ImportError("请安装 asyncpg: pip install asyncpg") from exc
        self.pool = await asyncpg.create_pool(self.database_url)

    async def _init_mysql(self):
        """Create the aiomysql connection pool from the parsed URL."""
        try:
            import aiomysql
        except ImportError as exc:
            raise ImportError("请安装 aiomysql: pip install aiomysql") from exc

        from urllib.parse import urlparse

        parsed = urlparse(self.database_url)
        self.pool = await aiomysql.create_pool(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=parsed.username,
            password=parsed.password,
            db=parsed.path.lstrip('/'),
            charset='utf8mb4',
        )

    async def _init_sqlite(self):
        """Check that aiosqlite is installed and derive the database path."""
        try:
            import aiosqlite  # noqa: F401 - availability check only
        except ImportError as exc:
            raise ImportError("请安装 aiosqlite: pip install aiosqlite") from exc
        self.db_path = self.database_url.replace("sqlite://", "")

    async def _create_tables(self):
        """Create both tables using the dialect of the active backend."""
        if self._db_type == "postgresql":
            await self._create_tables_postgresql()
        elif self._db_type == "mysql":
            await self._create_tables_mysql()
        elif self._db_type == "sqlite":
            await self._create_tables_sqlite()

    async def _create_tables_postgresql(self):
        """Create the PostgreSQL tables (JSONB value column)."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS kiro_config (
                    id SERIAL PRIMARY KEY,
                    key VARCHAR(255) UNIQUE NOT NULL,
                    value JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            await conn.execute("""
                CREATE TABLE IF NOT EXISTS kiro_admin (
                    id SERIAL PRIMARY KEY,
                    key VARCHAR(255) UNIQUE NOT NULL,
                    value JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

    async def _create_tables_mysql(self):
        """Create the MySQL tables (JSON value column, InnoDB, utf8mb4)."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS kiro_config (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        `key` VARCHAR(255) UNIQUE NOT NULL,
                        value JSON NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """)

                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS kiro_admin (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        `key` VARCHAR(255) UNIQUE NOT NULL,
                        value JSON NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """)

    async def _create_tables_sqlite(self):
        """Create the SQLite tables (JSON stored as TEXT)."""
        import aiosqlite

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                CREATE TABLE IF NOT EXISTS kiro_config (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE NOT NULL,
                    value TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            await db.execute("""
                CREATE TABLE IF NOT EXISTS kiro_admin (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE NOT NULL,
                    value TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            await db.commit()

    async def _get_config_value(self, table: str, key: str) -> Optional[Dict[str, Any]]:
        """Fetch and decode the JSON document stored under ``key``.

        ``table`` is interpolated into the statement, so it must be one of
        the internal table names, never user input. Returns ``None`` when
        the row is missing or the query fails (the failure is logged).
        """
        try:
            if self._db_type == "postgresql":
                async with self.pool.acquire() as conn:
                    row = await conn.fetchrow(f"SELECT value FROM {table} WHERE key = $1", key)
                    # asyncpg returns jsonb as text by default, so the value
                    # has to be parsed just like on the other backends.
                    return self._decode_json(row['value']) if row else None

            elif self._db_type == "mysql":
                async with self.pool.acquire() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(f"SELECT value FROM {table} WHERE `key` = %s", (key,))
                        row = await cursor.fetchone()
                        return self._decode_json(row[0]) if row else None

            elif self._db_type == "sqlite":
                import aiosqlite

                async with aiosqlite.connect(self.db_path) as db:
                    cursor = await db.execute(f"SELECT value FROM {table} WHERE key = ?", (key,))
                    row = await cursor.fetchone()
                    return self._decode_json(row[0]) if row else None

        except Exception as e:
            logger.error(f"获取配置失败 {table}.{key}: {e}")
        return None

    async def _set_config_value(self, table: str, key: str, value: Dict[str, Any]) -> bool:
        """Upsert ``value`` (JSON-encoded) under ``key``.

        ``table`` is interpolated into the statement, so it must be one of
        the internal table names. Returns ``False`` after logging when the
        write fails.
        """
        try:
            payload = json.dumps(value)

            if self._db_type == "postgresql":
                async with self.pool.acquire() as conn:
                    await conn.execute(f"""
                        INSERT INTO {table} (key, value) VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = CURRENT_TIMESTAMP
                    """, key, payload)

            elif self._db_type == "mysql":
                async with self.pool.acquire() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(f"""
                            INSERT INTO {table} (`key`, value) VALUES (%s, %s)
                            ON DUPLICATE KEY UPDATE value = %s, updated_at = CURRENT_TIMESTAMP
                        """, (key, payload, payload))
                        await conn.commit()

            elif self._db_type == "sqlite":
                import aiosqlite

                async with aiosqlite.connect(self.db_path) as db:
                    await db.execute(f"""
                        INSERT OR REPLACE INTO {table} (key, value, updated_at)
                        VALUES (?, ?, CURRENT_TIMESTAMP)
                    """, (key, payload))
                    await db.commit()

            return True
        except Exception as e:
            logger.error(f"设置配置失败 {table}.{key}: {e}")
            return False

    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Replace the ``"accounts"`` entry of the stored configuration."""
        config = await self.load_config()
        config["accounts"] = accounts
        return await self.save_config(config)

    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return the ``"accounts"`` list of the configuration (``[]`` if unset)."""
        config = await self.load_config()
        return config.get("accounts", [])

    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Store the complete configuration in ``kiro_config``."""
        return await self._set_config_value("kiro_config", "main", config)

    async def load_config(self) -> Dict[str, Any]:
        """Load the complete configuration (``{}`` when none is stored)."""
        config = await self._get_config_value("kiro_config", "main")
        return config or {}

    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Store the administrator configuration in ``kiro_admin``."""
        return await self._set_config_value("kiro_admin", "main", admin_config)

    async def load_admin_config(self) -> Dict[str, Any]:
        """Load the administrator configuration (``{}`` when none is stored)."""
        config = await self._get_config_value("kiro_admin", "main")
        return config or {}

    async def get_config_hash(self) -> str:
        """Return the MD5 hex digest of the configuration as sorted-key JSON.

        Returns ``""`` if hashing fails. Used for change detection only.
        """
        try:
            config = await self.load_config()
            content = json.dumps(config, sort_keys=True)
            return hashlib.md5(content.encode()).hexdigest()
        except Exception as e:
            logger.error(f"获取配置哈希失败: {e}")
            return ""
|
|
|
|
|
|
|
|
def _get_account_key(account: Dict[str, Any]) -> str:
    """Return a string identifying ``account`` for de-duplication.

    The first truthy field among ``email``, ``token`` and ``id`` decides the
    key; an account with none of them is keyed by a digest of its sorted-key
    JSON form.

    Tokens are keyed by a SHA-256 digest of the *whole* token rather than by
    a raw prefix: distinct tokens that share their leading characters must
    not collapse into one account, and the raw secret stays out of anything
    that prints the key. The fallback uses SHA-256 instead of the built-in
    ``hash()``, whose value for strings is salted per process.
    """
    if account.get("email"):
        return f"email:{account['email']}"
    if account.get("token"):
        token_digest = hashlib.sha256(str(account["token"]).encode("utf-8")).hexdigest()
        return f"token:{token_digest[:32]}"
    if account.get("id"):
        return f"id:{account['id']}"
    canonical = json.dumps(account, sort_keys=True)
    return f"hash:{hashlib.sha256(canonical.encode('utf-8')).hexdigest()}"
|
|
|
|
|
|
|
|
def _merge_accounts(local_accounts: List[Dict[str, Any]], remote_accounts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Union of remote and local accounts without duplicates; remote wins.

    Remote accounts are visited first, so for any key present on both sides
    the remote entry is the one kept. Every local account that is actually
    added is logged.
    """
    known = set()
    result = []

    for from_local, batch in ((False, remote_accounts), (True, local_accounts)):
        for entry in batch:
            identity = _get_account_key(entry)
            if identity in known:
                continue
            known.add(identity)
            result.append(entry)
            if from_local:
                logger.info(f"从本地合并账号: {identity}")

    return result
|
|
|
|
|
|
|
|
|
|
|
def create_database() -> DatabaseInterface:
    """Build the database backend selected by the ``DATABASE_URL`` variable.

    A URL with a recognised scheme yields an (uninitialised) ``SQLDatabase``.
    A missing/empty variable, or a URL whose scheme is not recognised (a
    warning is logged in that case), yields a ``FileSystemDatabase`` rooted
    at ``DATA_DIR``.
    """
    url = os.getenv("DATABASE_URL")

    if url:
        sql_backend = SQLDatabase(url)
        if sql_backend._db_type != "unknown":
            logger.info(f"使用远程数据库: {sql_backend._db_type}")
            return sql_backend
        logger.warning("DATABASE_URL 格式不支持,将回退到文件系统数据库")

    # Reached both when no URL is configured and when it is unsupported.
    from ..config import DATA_DIR
    logger.info(f"使用文件系统数据库: {DATA_DIR}")
    return FileSystemDatabase(DATA_DIR)
|
|
|
|
|
|
|
|
|
|
|
_db_instance: Optional[DatabaseInterface] = None
_merge_completed: bool = False


async def get_database() -> DatabaseInterface:
    """Return the shared database instance, creating it on first use.

    On first use the backend from ``create_database()`` is initialised; if
    that fails, a ``FileSystemDatabase`` on ``DATA_DIR`` is used instead.
    When the shared instance is a ``SQLDatabase``, local accounts are merged
    into it once per process.
    """
    global _db_instance, _merge_completed

    if _db_instance is None:
        backend = create_database()
        if not await backend.initialize():
            from ..config import DATA_DIR
            logger.warning(f"数据库初始化失败,将回退到文件系统数据库: {DATA_DIR}")
            backend = FileSystemDatabase(DATA_DIR)
            await backend.initialize()
        _db_instance = backend

    merge_pending = not _merge_completed and isinstance(_db_instance, SQLDatabase)
    if merge_pending:
        await _auto_merge_local_to_remote()
        _merge_completed = True

    return _db_instance
|
|
|
|
|
|
|
|
async def _auto_merge_local_to_remote():
    """Copy accounts that exist only in the local files into the remote DB.

    Loads the accounts from a ``FileSystemDatabase`` on ``DATA_DIR`` and
    from the shared ``_db_instance``, merges them with ``_merge_accounts``
    and writes the result back to ``_db_instance`` only when the merge added
    something. Never raises: every failure, including a rejected write, is
    logged.
    """
    try:
        from ..config import DATA_DIR

        local_db = FileSystemDatabase(DATA_DIR)
        await local_db.initialize()

        local_accounts = await local_db.load_accounts()
        if not local_accounts:
            logger.info("本地无账号,跳过合并")
            return

        remote_accounts = await _db_instance.load_accounts()
        merged = _merge_accounts(local_accounts, remote_accounts)

        new_count = len(merged) - len(remote_accounts)
        if new_count <= 0:
            logger.info("所有本地账号已存在于远程,无需合并")
            return

        # save_accounts signals failure through its return value, so success
        # must not be reported unless the write actually went through.
        if not await _db_instance.save_accounts(merged):
            logger.error("自动合并账号失败: 保存到远程数据库失败")
            return

        logger.info(f"已将 {new_count} 个本地账号合并到远程数据库,共 {len(merged)} 个账号")

    except Exception as e:
        logger.error(f"自动合并账号失败: {e}")
|
|
|
|
|
|
|
|
class SyncManager:
    """Two-way synchronisation of the account list between two stores.

    A background task polls both stores every ``SYNC_INTERVAL`` seconds
    (default 30), detects changes by comparing each store's config hash with
    the value remembered after the last successful sync, and on any change
    merges the two account lists and writes the result to both sides.
    Registered callbacks run after every successful sync.
    """

    def __init__(self, local_db: "FileSystemDatabase", remote_db: "SQLDatabase"):
        self.local_db = local_db
        self.remote_db = remote_db
        # Hashes captured after the last successful sync; "" = never synced.
        self._local_hash: str = ""
        self._remote_hash: str = ""
        self._sync_lock = asyncio.Lock()
        self._running = False
        self._sync_task: Optional[asyncio.Task] = None
        self._sync_interval = self._read_sync_interval()
        self._callbacks: List[Callable] = []

    @staticmethod
    def _read_sync_interval(default: int = 30) -> int:
        """Read ``SYNC_INTERVAL`` from the environment as a positive int.

        Falls back to ``default`` (with a warning) for values that are not
        integers or not positive: a non-positive interval would turn the
        polling loop into a busy loop.
        """
        raw = os.getenv("SYNC_INTERVAL", str(default))
        try:
            interval = int(raw)
        except ValueError:
            interval = 0
        if interval <= 0:
            logger.warning(f"SYNC_INTERVAL 配置无效: {raw!r},使用默认值 {default} 秒")
            return default
        return interval

    @staticmethod
    def _canonical_accounts(accounts: List[Dict[str, Any]]) -> List[str]:
        """Order- and key-order-independent representation of an account list."""
        return sorted(json.dumps(account, sort_keys=True) for account in accounts)

    def register_callback(self, callback: Callable):
        """Register a callable (sync or async) to run after each sync."""
        self._callbacks.append(callback)

    async def _notify_callbacks(self):
        """Run every registered callback; a failing one is logged, not raised."""
        for callback in self._callbacks:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback()
                else:
                    callback()
            except Exception as e:
                logger.error(f"同步回调执行失败: {e}")

    async def _save_to_both(self, accounts: List[Dict[str, Any]]) -> bool:
        """Write ``accounts`` to both stores; ``True`` only if both succeed."""
        local_ok = await self.local_db.save_accounts(accounts)
        remote_ok = await self.remote_db.save_accounts(accounts)
        return bool(local_ok) and bool(remote_ok)

    async def _snapshot_hashes(self):
        """Remember the current hashes as the "last synced" state."""
        self._local_hash = await self.local_db.get_config_hash()
        self._remote_hash = await self.remote_db.get_config_hash()

    async def start(self):
        """Run the initial sync and start the background loop (idempotent)."""
        if self._running:
            return

        self._running = True
        await self._initial_sync()

        self._sync_task = asyncio.create_task(self._sync_loop())
        logger.info(f"双向同步已启动,间隔 {self._sync_interval} 秒")

    async def stop(self):
        """Cancel the background loop and wait for it to finish."""
        self._running = False
        if self._sync_task:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass
            self._sync_task = None
        logger.info("双向同步已停止")

    async def _initial_sync(self):
        """Merge both sides once at start-up and store the result on both."""
        async with self._sync_lock:
            try:
                local_accounts = await self.local_db.load_accounts()
                remote_accounts = await self.remote_db.load_accounts()

                merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)

                if not await self._save_to_both(merged):
                    # Leave the snapshots untouched so the polling loop still
                    # sees a difference and retries.
                    logger.error("初始同步失败: 账号保存失败")
                    return

                await self._snapshot_hashes()

                logger.info(f"初始同步完成,共 {len(merged)} 个账号")
                await self._notify_callbacks()

            except Exception as e:
                logger.error(f"初始同步失败: {e}")

    async def _sync_loop(self):
        """Sleep for the interval, then check for changes, until stopped."""
        while self._running:
            try:
                await asyncio.sleep(self._sync_interval)
                await self._check_and_sync()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"同步循环异常: {e}")

    async def _check_and_sync(self):
        """Merge and store on both sides when either hash has changed."""
        async with self._sync_lock:
            try:
                current_local_hash = await self.local_db.get_config_hash()
                current_remote_hash = await self.remote_db.get_config_hash()

                local_changed = current_local_hash != self._local_hash
                remote_changed = current_remote_hash != self._remote_hash

                if not local_changed and not remote_changed:
                    return

                local_accounts = await self.local_db.load_accounts()
                remote_accounts = await self.remote_db.load_accounts()

                # Every case performs the same merge; only the log differs.
                if local_changed and remote_changed:
                    logger.info("检测到本地和远程都有变更,执行合并")
                elif local_changed:
                    logger.info("检测到本地变更,同步到远程")
                else:
                    logger.info("检测到远程变更,同步到本地")

                merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)

                if not await self._save_to_both(merged):
                    # Keeping the old snapshots makes the next cycle detect
                    # the change again instead of silently dropping it.
                    logger.error("同步检查失败: 账号保存失败,将在下次循环重试")
                    return

                await self._snapshot_hashes()
                await self._notify_callbacks()

            except Exception as e:
                logger.error(f"同步检查失败: {e}")

    async def force_sync(self, source: str = "merge") -> bool:
        """Synchronise immediately.

        Args:
            source: ``"local"`` overwrites remote with the local accounts,
                ``"remote"`` overwrites local with the remote accounts; any
                other value (default ``"merge"``) merges both sides.

        Returns:
            ``True`` when the accounts were stored, ``False`` when a write
            was rejected or an exception occurred (both are logged).
        """
        async with self._sync_lock:
            try:
                if source == "local":
                    accounts = await self.local_db.load_accounts()
                    saved = await self.remote_db.save_accounts(accounts)
                    summary = f"强制同步:本地 -> 远程,{len(accounts)} 个账号"

                elif source == "remote":
                    accounts = await self.remote_db.load_accounts()
                    saved = await self.local_db.save_accounts(accounts)
                    summary = f"强制同步:远程 -> 本地,{len(accounts)} 个账号"

                else:
                    local_accounts = await self.local_db.load_accounts()
                    remote_accounts = await self.remote_db.load_accounts()
                    merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)
                    saved = await self._save_to_both(merged)
                    summary = f"强制同步:双向合并,{len(merged)} 个账号"

                if not saved:
                    logger.error("强制同步失败: 账号保存失败")
                    return False

                logger.info(summary)
                await self._snapshot_hashes()
                await self._notify_callbacks()
                return True

            except Exception as e:
                logger.error(f"强制同步失败: {e}")
                return False

    async def get_sync_status(self) -> Dict[str, Any]:
        """Report the current synchronisation state.

        ``local_changed`` / ``remote_changed`` compare each store's current
        hash with the one remembered after the last successful sync.
        ``in_sync`` compares the account lists themselves: the two stores
        hash different serialisations of their configuration, so their
        hashes cannot be compared with each other. On failure a reduced dict
        carrying an ``"error"`` message is returned.
        """
        try:
            local_accounts = await self.local_db.load_accounts()
            remote_accounts = await self.remote_db.load_accounts()

            current_local_hash = await self.local_db.get_config_hash()
            current_remote_hash = await self.remote_db.get_config_hash()

            accounts_equal = (
                self._canonical_accounts(local_accounts)
                == self._canonical_accounts(remote_accounts)
            )

            return {
                "enabled": True,
                "running": self._running,
                "sync_interval": self._sync_interval,
                "local_accounts": len(local_accounts),
                "remote_accounts": len(remote_accounts),
                "local_changed": current_local_hash != self._local_hash,
                "remote_changed": current_remote_hash != self._remote_hash,
                "in_sync": accounts_equal,
            }
        except Exception as e:
            return {
                "enabled": True,
                "running": self._running,
                "error": str(e),
            }
|
|
|
|
|
|
|
|
def _timestamp_is_newer(candidate: Any, baseline: Any) -> bool:
    """Return whether ``candidate`` is strictly newer than ``baseline``.

    Values that cannot be ordered against each other (for example a string
    timestamp against the numeric default, or ``None``) do not raise: a
    present timestamp then beats a missing one, otherwise the baseline stays.
    """
    try:
        return bool(candidate > baseline)
    except TypeError:
        return bool(candidate) and not baseline


def _merge_accounts_bidirectional(
    local_accounts: List[Dict[str, Any]],
    remote_accounts: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Merge local and remote accounts, keeping the newer copy per account.

    Accounts are matched by ``_get_account_key``. Local accounts form the
    base; a remote account replaces the entry with the same key only when
    its ``updated_at`` is strictly newer (a missing ``updated_at`` counts as
    ``0``), and is added when the key is not present yet.

    The input dictionaries are left untouched. The returned accounts are
    fresh dictionaries from which keys starting with ``"_"`` are dropped.
    """
    accounts_map: Dict[str, Dict[str, Any]] = {}

    for account in local_accounts:
        accounts_map[_get_account_key(account)] = account

    for account in remote_accounts:
        key = _get_account_key(account)
        current = accounts_map.get(key)
        if current is None or _timestamp_is_newer(
            account.get("updated_at", 0), current.get("updated_at", 0)
        ):
            accounts_map[key] = account

    return [
        {k: v for k, v in account.items() if not k.startswith("_")}
        for account in accounts_map.values()
    ]
|
|
|
|
|
|
|
|
_sync_manager: Optional[SyncManager] = None


async def get_sync_manager() -> Optional[SyncManager]:
    """Return the module-level sync manager (``None`` if none is set)."""
    return _sync_manager
|
|
|
|
|
|
|
|
async def start_sync_manager():
    """Create and start the module-level ``SyncManager``.

    Does nothing (apart from logging) when a manager is already registered,
    when ``DATABASE_URL`` is unset or has an unsupported scheme, or when
    either store fails to initialise.
    """
    global _sync_manager

    # Starting a second manager would replace the global and leave the first
    # one's background task running with nothing left to stop it.
    if _sync_manager is not None:
        logger.info("同步管理器已在运行,跳过重复启动")
        return

    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        logger.info("未配置 DATABASE_URL,跳过同步管理器")
        return

    from ..config import DATA_DIR

    local_db = FileSystemDatabase(DATA_DIR)
    if not await local_db.initialize():
        logger.warning("本地文件系统初始化失败,跳过同步管理器")
        return

    remote_db = SQLDatabase(database_url)
    if remote_db._db_type == "unknown":
        logger.warning("DATABASE_URL 格式不支持,跳过同步管理器")
        return

    ok = await remote_db.initialize()
    if not ok:
        logger.warning("远程数据库初始化失败,跳过同步管理器")
        return

    _sync_manager = SyncManager(local_db, remote_db)
    await _sync_manager.start()
|
|
|
|
|
|
|
|
async def stop_sync_manager():
    """Stop the module-level sync manager, if any, and clear the reference."""
    global _sync_manager

    if not _sync_manager:
        return

    await _sync_manager.stop()
    _sync_manager = None
|
|
|