kiroproxy / kiro_proxy /core /database.py
KiroProxy User
Fix admin auth crash in async context
320ef8e
"""数据库抽象层 - 支持文件系统和远程SQL数据库,以及双向同步"""
import os
import json
import asyncio
import hashlib
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Callable
from pathlib import Path
import logging
import time
logger = logging.getLogger(__name__)
class DatabaseInterface(ABC):
    """Abstract async storage backend for the proxy configuration.

    Concrete backends store two JSON-like documents: the main configuration
    (which carries the ``accounts`` list) and a separate admin configuration.
    They also expose a hash of the main configuration for change detection.
    Every operation is a coroutine.
    """

    @abstractmethod
    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Store ``accounts`` in the main configuration; report success as a bool."""

    @abstractmethod
    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return the ``accounts`` list held by the main configuration."""

    @abstractmethod
    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Replace the whole main configuration document; report success as a bool."""

    @abstractmethod
    async def load_config(self) -> Dict[str, Any]:
        """Return the whole main configuration document."""

    @abstractmethod
    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Replace the admin configuration document; report success as a bool."""

    @abstractmethod
    async def load_admin_config(self) -> Dict[str, Any]:
        """Return the admin configuration document."""

    @abstractmethod
    async def initialize(self) -> bool:
        """Prepare the backend for use; report success as a bool."""

    @abstractmethod
    async def get_config_hash(self) -> str:
        """Return a digest of the main configuration, used to detect changes."""
class FileSystemDatabase(DatabaseInterface):
    """File-based backend (the original storage): JSON files under ``data_dir``.

    The main config lives in ``config.json`` and the admin config in
    ``admin.json``. Saves write a temporary sibling file and rename it over
    the target, so a failed or interrupted write never leaves a truncated file.
    """

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.config_file = data_dir / "config.json"
        self.admin_config_file = data_dir / "admin.json"

    def _ensure_dir(self):
        """Create ``data_dir`` (and any parents) if it does not exist yet."""
        self.data_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _write_json_atomic(path: Path, data: Any) -> None:
        """Serialise ``data`` as indented UTF-8 JSON and atomically replace ``path``.

        Writing straight into the target truncates it first. A non-serialisable
        value or a crash mid-write would then leave a partial file. load_config()
        reads such a file as ``{}``, and the next save_accounts() would persist
        that empty config, dropping every other setting.

        Raises whatever the serialisation or file operations raise; the
        temporary file is removed in that case.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
                f.flush()
                os.fsync(f.fileno())
            if path.exists():
                # Keep the existing file's permission bits (e.g. a chmod 600 admin.json).
                os.chmod(tmp_path, path.stat().st_mode & 0o7777)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    async def initialize(self) -> bool:
        """Ensure the data directory exists; return False (and log) on failure."""
        try:
            self._ensure_dir()
            return True
        except Exception as e:
            logger.error(f"文件系统初始化失败: {e}")
            return False

    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Replace the ``accounts`` entry of config.json, keeping its other keys."""
        try:
            self._ensure_dir()
            config = await self.load_config()
            config["accounts"] = accounts
            return await self.save_config(config)
        except Exception as e:
            logger.error(f"保存账号配置失败: {e}")
            return False

    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return ``config["accounts"]``, or ``[]`` when absent."""
        config = await self.load_config()
        return config.get("accounts", [])

    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Atomically overwrite config.json with ``config``; False (and log) on error."""
        try:
            self._ensure_dir()
            self._write_json_atomic(self.config_file, config)
            return True
        except Exception as e:
            logger.error(f"保存配置失败: {e}")
            return False

    async def load_config(self) -> Dict[str, Any]:
        """Return the parsed config.json.

        Returns ``{}`` when the file is missing, unreadable, or not a JSON object.
        Callers use ``.get`` on the result, so a non-object is rejected here.
        """
        try:
            if self.config_file.exists():
                with open(self.config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
                if isinstance(config, dict):
                    return config
                logger.error(f"加载配置失败: {self.config_file} 不是 JSON 对象")
        except Exception as e:
            logger.error(f"加载配置失败: {e}")
        return {}

    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Atomically overwrite admin.json with ``admin_config``; False (and log) on error."""
        try:
            self._ensure_dir()
            self._write_json_atomic(self.admin_config_file, admin_config)
            return True
        except Exception as e:
            logger.error(f"保存管理员配置失败: {e}")
            return False

    async def load_admin_config(self) -> Dict[str, Any]:
        """Return the parsed admin.json.

        Returns ``{}`` when the file is missing, unreadable, or not a JSON object.
        """
        try:
            if self.admin_config_file.exists():
                with open(self.admin_config_file, "r", encoding="utf-8") as f:
                    admin_config = json.load(f)
                if isinstance(admin_config, dict):
                    return admin_config
                logger.error(f"加载管理员配置失败: {self.admin_config_file} 不是 JSON 对象")
        except Exception as e:
            logger.error(f"加载管理员配置失败: {e}")
        return {}

    async def get_config_hash(self) -> str:
        """Return the MD5 hex digest of config.json's raw bytes.

        Returns ``""`` if the file is missing or unreadable.
        """
        try:
            if self.config_file.exists():
                content = self.config_file.read_bytes()
                return hashlib.md5(content).hexdigest()
        except Exception as e:
            logger.error(f"获取配置哈希失败: {e}")
        return ""
class SQLDatabase(DatabaseInterface):
    """SQL backend storing each config document as one JSON row keyed by name.

    The driver is chosen from the URL scheme: asyncpg for PostgreSQL,
    aiomysql for MySQL, aiosqlite for SQLite. Tables ``kiro_config`` and
    ``kiro_admin`` each hold a single row under key ``"main"``.
    """

    def __init__(self, database_url: str):
        self.database_url = database_url
        self.pool = None
        self._db_type = self._detect_db_type(database_url)

    def _detect_db_type(self, url: str) -> str:
        """Map the URL scheme to "postgresql", "mysql", "sqlite" or "unknown"."""
        if url.startswith("postgresql://") or url.startswith("postgres://"):
            return "postgresql"
        elif url.startswith("mysql://") or url.startswith("mysql+"):
            return "mysql"
        elif url.startswith("sqlite://"):
            return "sqlite"
        else:
            return "unknown"

    @staticmethod
    def _decode_json(raw: Any) -> Any:
        """Turn a stored JSON column value into Python data.

        asyncpg returns json/jsonb columns as ``str`` unless a type codec is
        registered, which the pool here never does. MySQL and SQLite drivers
        also hand back text. So str/bytes are parsed, and anything already
        decoded is passed through.
        """
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    async def initialize(self) -> bool:
        """Connect and create the tables; return False (and log) on any failure."""
        try:
            if self._db_type == "postgresql":
                await self._init_postgresql()
            elif self._db_type == "mysql":
                await self._init_mysql()
            elif self._db_type == "sqlite":
                await self._init_sqlite()
            else:
                logger.error(f"不支持的数据库类型: {self._db_type}")
                return False
            await self._create_tables()
            return True
        except Exception as e:
            logger.error(f"SQL数据库初始化失败: {e}")
            return False

    async def _init_postgresql(self):
        """Create an asyncpg connection pool from the URL."""
        try:
            import asyncpg
            self.pool = await asyncpg.create_pool(self.database_url)
        except ImportError:
            raise ImportError("请安装 asyncpg: pip install asyncpg")

    async def _init_mysql(self):
        """Create an aiomysql pool from a ``mysql://user:pass@host:port/db`` URL."""
        try:
            import aiomysql
        except ImportError:
            raise ImportError("请安装 aiomysql: pip install aiomysql")
        from urllib.parse import unquote, urlparse

        parsed = urlparse(self.database_url)
        self.pool = await aiomysql.create_pool(
            host=parsed.hostname,
            port=parsed.port or 3306,
            # urlparse keeps percent-escapes, e.g. a password containing "@"
            # must appear as "%40" in the URL and has to be decoded here.
            user=unquote(parsed.username) if parsed.username else parsed.username,
            password=unquote(parsed.password) if parsed.password else parsed.password,
            db=unquote(parsed.path.lstrip('/')),
            charset='utf8mb4',
            # Without autocommit a pooled connection keeps the REPEATABLE-READ
            # transaction opened by its last SELECT. Later polls from the sync
            # manager could then see a stale snapshot. Writes still call commit(),
            # which is harmless under autocommit.
            autocommit=True,
        )

    async def _init_sqlite(self):
        """Check that aiosqlite is installed and remember the database file path."""
        try:
            import aiosqlite
            db_path = self.database_url.replace("sqlite://", "")
            self.db_path = db_path
        except ImportError:
            raise ImportError("请安装 aiosqlite: pip install aiosqlite")

    async def _create_tables(self):
        """Create the config tables for the detected database type."""
        if self._db_type == "postgresql":
            await self._create_tables_postgresql()
        elif self._db_type == "mysql":
            await self._create_tables_mysql()
        elif self._db_type == "sqlite":
            await self._create_tables_sqlite()

    async def _create_tables_postgresql(self):
        """Create kiro_config / kiro_admin (JSONB values) if missing."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS kiro_config (
                    id SERIAL PRIMARY KEY,
                    key VARCHAR(255) UNIQUE NOT NULL,
                    value JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS kiro_admin (
                    id SERIAL PRIMARY KEY,
                    key VARCHAR(255) UNIQUE NOT NULL,
                    value JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

    async def _create_tables_mysql(self):
        """Create kiro_config / kiro_admin (JSON values, InnoDB, utf8mb4) if missing."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS kiro_config (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        `key` VARCHAR(255) UNIQUE NOT NULL,
                        value JSON NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """)
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS kiro_admin (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        `key` VARCHAR(255) UNIQUE NOT NULL,
                        value JSON NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """)

    async def _create_tables_sqlite(self):
        """Create kiro_config / kiro_admin (TEXT values holding JSON) if missing."""
        import aiosqlite
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                CREATE TABLE IF NOT EXISTS kiro_config (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE NOT NULL,
                    value TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS kiro_admin (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE NOT NULL,
                    value TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            await db.commit()

    async def _get_config_value(self, table: str, key: str) -> Optional[Dict[str, Any]]:
        """Read and decode the JSON ``value`` stored under ``key`` in ``table``.

        Returns None if the row is missing or on any error (which is logged).
        ``table`` is interpolated into the SQL text, so it must only ever be
        one of this class's own table names.
        """
        try:
            if self._db_type == "postgresql":
                async with self.pool.acquire() as conn:
                    row = await conn.fetchrow(f"SELECT value FROM {table} WHERE key = $1", key)
                    return self._decode_json(row['value']) if row else None
            elif self._db_type == "mysql":
                async with self.pool.acquire() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(f"SELECT value FROM {table} WHERE `key` = %s", (key,))
                        row = await cursor.fetchone()
                        return self._decode_json(row[0]) if row else None
            elif self._db_type == "sqlite":
                import aiosqlite
                async with aiosqlite.connect(self.db_path) as db:
                    cursor = await db.execute(f"SELECT value FROM {table} WHERE key = ?", (key,))
                    row = await cursor.fetchone()
                    return self._decode_json(row[0]) if row else None
        except Exception as e:
            logger.error(f"获取配置失败 {table}.{key}: {e}")
        return None

    async def _set_config_value(self, table: str, key: str, value: Dict[str, Any]) -> bool:
        """Upsert ``value`` (serialised as JSON) under ``key``; False (and log) on error."""
        try:
            if self._db_type == "postgresql":
                async with self.pool.acquire() as conn:
                    await conn.execute(f"""
                        INSERT INTO {table} (key, value) VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = CURRENT_TIMESTAMP
                    """, key, json.dumps(value))
            elif self._db_type == "mysql":
                async with self.pool.acquire() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(f"""
                            INSERT INTO {table} (`key`, value) VALUES (%s, %s)
                            ON DUPLICATE KEY UPDATE value = %s, updated_at = CURRENT_TIMESTAMP
                        """, (key, json.dumps(value), json.dumps(value)))
                    await conn.commit()
            elif self._db_type == "sqlite":
                import aiosqlite
                async with aiosqlite.connect(self.db_path) as db:
                    await db.execute(f"""
                        INSERT OR REPLACE INTO {table} (key, value, updated_at)
                        VALUES (?, ?, CURRENT_TIMESTAMP)
                    """, (key, json.dumps(value)))
                    await db.commit()
            return True
        except Exception as e:
            logger.error(f"设置配置失败 {table}.{key}: {e}")
            return False

    async def save_accounts(self, accounts: List[Dict[str, Any]]) -> bool:
        """Replace the ``accounts`` entry of the main config, keeping its other keys."""
        config = await self.load_config()
        config["accounts"] = accounts
        return await self.save_config(config)

    async def load_accounts(self) -> List[Dict[str, Any]]:
        """Return ``config["accounts"]``, or ``[]`` when absent."""
        config = await self.load_config()
        return config.get("accounts", [])

    async def save_config(self, config: Dict[str, Any]) -> bool:
        """Store the whole main config as row ``kiro_config.main``."""
        return await self._set_config_value("kiro_config", "main", config)

    async def load_config(self) -> Dict[str, Any]:
        """Return row ``kiro_config.main``, or ``{}`` when missing or unreadable."""
        config = await self._get_config_value("kiro_config", "main")
        return config or {}

    async def save_admin_config(self, admin_config: Dict[str, Any]) -> bool:
        """Store the admin config as row ``kiro_admin.main``."""
        return await self._set_config_value("kiro_admin", "main", admin_config)

    async def load_admin_config(self) -> Dict[str, Any]:
        """Return row ``kiro_admin.main``, or ``{}`` when missing or unreadable."""
        config = await self._get_config_value("kiro_admin", "main")
        return config or {}

    async def get_config_hash(self) -> str:
        """Return the MD5 of the main config dumped with sorted keys; ``""`` on error."""
        try:
            config = await self.load_config()
            content = json.dumps(config, sort_keys=True)
            return hashlib.md5(content.encode()).hexdigest()
        except Exception as e:
            logger.error(f"获取配置哈希失败: {e}")
            return ""
def _get_account_key(account: Dict[str, Any]) -> str:
    """Return the identity key used to de-duplicate accounts.

    Precedence is ``email``, then the first 32 characters of ``token``, then
    ``id`` (each only when truthy). Accounts with none of these fall back to
    a SHA-256 digest of their canonical (sorted-key) JSON.

    The fallback uses hashlib rather than the builtin ``hash()``, whose value
    for strings is salted per interpreter run (PYTHONHASHSEED). This keeps the
    key stable across processes and restarts.
    """
    if account.get("email"):
        return f"email:{account['email']}"
    if account.get("token"):
        return f"token:{account['token'][:32]}"
    if account.get("id"):
        return f"id:{account['id']}"
    # default=str: a non-JSON value (e.g. a datetime) should still yield a
    # fingerprint instead of aborting the whole merge with TypeError.
    canonical = json.dumps(account, sort_keys=True, default=str)
    return f"hash:{hashlib.sha256(canonical.encode('utf-8')).hexdigest()}"
def _redact_account_key(key: str) -> str:
    """Mask token-derived identity keys so credential material never reaches the logs."""
    if key.startswith("token:"):
        return f"token:{key[len('token:'):len('token:') + 4]}***"
    return key


def _merge_accounts(local_accounts: List[Dict[str, Any]], remote_accounts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Merge local accounts into the remote list, de-duplicated by identity key.

    Remote accounts come first, in order, with duplicates dropped. Local
    accounts whose key is not already present are appended after them. On a
    key collision the remote account wins. The input dicts are returned as-is,
    not copied.
    """
    seen_keys = set()
    merged = []
    for origin, accounts in (("remote", remote_accounts), ("local", local_accounts)):
        for account in accounts:
            key = _get_account_key(account)
            if key in seen_keys:
                continue
            seen_keys.add(key)
            merged.append(account)
            if origin == "local":
                # Keys of token accounts embed the token prefix; redact before logging.
                logger.info(f"从本地合并账号: {_redact_account_key(key)}")
    return merged
# Database factory
def create_database() -> DatabaseInterface:
    """Choose the storage backend from the DATABASE_URL environment variable.

    A URL with a recognised scheme yields an (uninitialised) SQLDatabase.
    A missing or empty URL, or an unrecognised scheme, yields a
    FileSystemDatabase rooted at DATA_DIR.
    """
    database_url = os.getenv("DATABASE_URL")
    if database_url:
        sql_db = SQLDatabase(database_url)
        if sql_db._db_type != "unknown":
            logger.info(f"使用远程数据库: {sql_db._db_type}")
            return sql_db
        logger.warning("DATABASE_URL 格式不支持,将回退到文件系统数据库")

    from ..config import DATA_DIR
    logger.info(f"使用文件系统数据库: {DATA_DIR}")
    return FileSystemDatabase(DATA_DIR)
# Process-wide database instance and one-shot merge flag
_db_instance: Optional[DatabaseInterface] = None
_merge_completed: bool = False


async def get_database() -> DatabaseInterface:
    """Return the shared database backend, creating it on first use.

    The backend comes from create_database(). If its initialize() returns a
    falsy value, a FileSystemDatabase on DATA_DIR is used instead. While the
    active backend is a SQLDatabase and the merge has not yet completed,
    local accounts are merged into it via _auto_merge_local_to_remote().
    """
    global _db_instance, _merge_completed

    if _db_instance is None:
        backend = create_database()
        if not await backend.initialize():
            from ..config import DATA_DIR
            logger.warning(f"数据库初始化失败,将回退到文件系统数据库: {DATA_DIR}")
            backend = FileSystemDatabase(DATA_DIR)
            await backend.initialize()
        _db_instance = backend

    needs_merge = not _merge_completed and isinstance(_db_instance, SQLDatabase)
    if needs_merge:
        await _auto_merge_local_to_remote()
        _merge_completed = True
    return _db_instance
async def _auto_merge_local_to_remote():
    """Copy local file-store accounts that the remote backend lacks into it.

    Called by get_database() in this module when the shared instance is a
    SQLDatabase. Remote entries win on identity-key collisions (see
    _merge_accounts). Every exception is logged and swallowed, so a failed
    merge does not propagate to the caller.
    """
    try:
        from ..config import DATA_DIR
        local_db = FileSystemDatabase(DATA_DIR)
        await local_db.initialize()
        local_accounts = await local_db.load_accounts()
        if not local_accounts:
            logger.info("本地无账号,跳过合并")
            return

        remote_accounts = await _db_instance.load_accounts()
        merged = _merge_accounts(local_accounts, remote_accounts)

        # Count additions by identity key instead of comparing list lengths.
        # _merge_accounts also drops duplicates inside the remote list, so
        # len(merged) can be <= len(remote_accounts) even though new local
        # accounts were added.
        remote_keys = {_get_account_key(account) for account in remote_accounts}
        new_count = sum(1 for account in merged if _get_account_key(account) not in remote_keys)

        if not new_count:
            logger.info("所有本地账号已存在于远程,无需合并")
            return
        if await _db_instance.save_accounts(merged):
            logger.info(f"已将 {new_count} 个本地账号合并到远程数据库,共 {len(merged)} 个账号")
        else:
            logger.error("自动合并账号失败: 写入远程数据库失败")
    except Exception as e:
        logger.error(f"自动合并账号失败: {e}")
class SyncManager:
    """Two-way account synchronisation between the local file store and the SQL store.

    On each polling tick the current ``get_config_hash()`` of both sides is
    compared with the value remembered after the previous sync. If either
    changed, the account lists are merged with ``_merge_accounts_bidirectional``
    and written back to both sides. Only the ``accounts`` list is synchronised.
    """

    def __init__(self, local_db: FileSystemDatabase, remote_db: SQLDatabase):
        self.local_db = local_db
        self.remote_db = remote_db
        self._local_hash: str = ""
        self._remote_hash: str = ""
        self._sync_lock = asyncio.Lock()
        self._running = False
        self._sync_task: Optional[asyncio.Task] = None
        self._sync_interval = self._read_sync_interval()
        self._callbacks: List[Callable] = []

    @staticmethod
    def _read_sync_interval(default: int = 30) -> int:
        """Read the polling interval in seconds from SYNC_INTERVAL.

        A malformed value falls back to ``default`` instead of raising from
        the constructor. Values below 1 are clamped to 1 so the loop cannot
        degrade into ``asyncio.sleep(0)`` hammering the remote database.
        """
        raw = os.getenv("SYNC_INTERVAL", str(default))
        try:
            interval = int(raw)
        except ValueError:
            logger.warning(f"SYNC_INTERVAL 配置无效: {raw!r},使用默认值 {default} 秒")
            interval = default
        return max(interval, 1)

    @staticmethod
    def _accounts_digest(accounts: List[Dict[str, Any]]) -> str:
        """Hash an account list independently of how a backend stores it.

        The backends' get_config_hash() values cannot be compared with each
        other. The file store hashes the raw bytes of config.json (indented,
        unsorted keys, and including non-account settings). The SQL store
        hashes a compact sort_keys dump. Equal data therefore never gives
        equal hashes.
        """
        canonical = json.dumps(accounts, sort_keys=True)
        return hashlib.md5(canonical.encode("utf-8")).hexdigest()

    def register_callback(self, callback: Callable):
        """Register a sync or async callable invoked after each completed sync."""
        self._callbacks.append(callback)

    async def _notify_callbacks(self):
        """Invoke every registered callback; failures are logged, not raised."""
        import inspect

        for callback in self._callbacks:
            try:
                result = callback()
                # Covers ``async def`` callbacks as well as wrappers such as
                # functools.partial(async_fn) that return an awaitable.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.error(f"同步回调执行失败: {e}")

    async def start(self):
        """Run the initial merge, then start the background polling task (no-op if running)."""
        if self._running:
            return
        self._running = True
        await self._initial_sync()
        self._sync_task = asyncio.create_task(self._sync_loop())
        logger.info(f"双向同步已启动,间隔 {self._sync_interval} 秒")

    async def stop(self):
        """Stop polling: cancel the background task and wait for it to finish."""
        self._running = False
        if self._sync_task:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass
        logger.info("双向同步已停止")

    async def _initial_sync(self):
        """Merge both sides once, write the result to both, and record the hashes."""
        async with self._sync_lock:
            try:
                local_accounts = await self.local_db.load_accounts()
                remote_accounts = await self.remote_db.load_accounts()
                merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)
                await self.local_db.save_accounts(merged)
                await self.remote_db.save_accounts(merged)
                self._local_hash = await self.local_db.get_config_hash()
                self._remote_hash = await self.remote_db.get_config_hash()
                logger.info(f"初始同步完成,共 {len(merged)} 个账号")
                await self._notify_callbacks()
            except Exception as e:
                logger.error(f"初始同步失败: {e}")

    async def _sync_loop(self):
        """Sleep for the interval, then check for changes, until stopped."""
        while self._running:
            try:
                await asyncio.sleep(self._sync_interval)
                await self._check_and_sync()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"同步循环异常: {e}")

    async def _check_and_sync(self):
        """Merge and write back accounts if either side changed since the last sync."""
        async with self._sync_lock:
            try:
                current_local_hash = await self.local_db.get_config_hash()
                current_remote_hash = await self.remote_db.get_config_hash()
                local_changed = current_local_hash != self._local_hash
                remote_changed = current_remote_hash != self._remote_hash
                if not local_changed and not remote_changed:
                    return

                local_accounts = await self.local_db.load_accounts()
                remote_accounts = await self.remote_db.load_accounts()

                # Every case writes the same merged list to both sides. Only the
                # log line and the write order differ: when only local changed,
                # the remote side is written first.
                if local_changed and remote_changed:
                    logger.info("检测到本地和远程都有变更,执行合并")
                    targets = (self.local_db, self.remote_db)
                elif local_changed:
                    logger.info("检测到本地变更,同步到远程")
                    targets = (self.remote_db, self.local_db)
                else:
                    logger.info("检测到远程变更,同步到本地")
                    targets = (self.local_db, self.remote_db)

                merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)
                for db in targets:
                    await db.save_accounts(merged)

                self._local_hash = await self.local_db.get_config_hash()
                self._remote_hash = await self.remote_db.get_config_hash()
                await self._notify_callbacks()
            except Exception as e:
                logger.error(f"同步检查失败: {e}")

    async def force_sync(self, source: str = "merge"):
        """Synchronise immediately regardless of detected changes.

        Args:
            source: "local" overwrites remote accounts with local ones,
                "remote" overwrites local accounts with remote ones, and
                anything else performs a two-way merge.

        Returns:
            True on success, False if an exception was raised (and logged).
        """
        async with self._sync_lock:
            try:
                if source == "local":
                    accounts = await self.local_db.load_accounts()
                    await self.remote_db.save_accounts(accounts)
                    logger.info(f"强制同步:本地 -> 远程,{len(accounts)} 个账号")
                elif source == "remote":
                    accounts = await self.remote_db.load_accounts()
                    await self.local_db.save_accounts(accounts)
                    logger.info(f"强制同步:远程 -> 本地,{len(accounts)} 个账号")
                else:
                    local_accounts = await self.local_db.load_accounts()
                    remote_accounts = await self.remote_db.load_accounts()
                    merged = _merge_accounts_bidirectional(local_accounts, remote_accounts)
                    await self.local_db.save_accounts(merged)
                    await self.remote_db.save_accounts(merged)
                    logger.info(f"强制同步:双向合并,{len(merged)} 个账号")
                self._local_hash = await self.local_db.get_config_hash()
                self._remote_hash = await self.remote_db.get_config_hash()
                await self._notify_callbacks()
                return True
            except Exception as e:
                logger.error(f"强制同步失败: {e}")
                return False

    async def get_sync_status(self) -> Dict[str, Any]:
        """Return a status snapshot: account counts, pending changes, and in-sync flag.

        ``in_sync`` compares canonical digests of the two account lists,
        since the backends' own config hashes are computed differently.
        """
        try:
            local_accounts = await self.local_db.load_accounts()
            remote_accounts = await self.remote_db.load_accounts()
            current_local_hash = await self.local_db.get_config_hash()
            current_remote_hash = await self.remote_db.get_config_hash()
            return {
                "enabled": True,
                "running": self._running,
                "sync_interval": self._sync_interval,
                "local_accounts": len(local_accounts),
                "remote_accounts": len(remote_accounts),
                "local_changed": current_local_hash != self._local_hash,
                "remote_changed": current_remote_hash != self._remote_hash,
                "in_sync": self._accounts_digest(local_accounts) == self._accounts_digest(remote_accounts),
            }
        except Exception as e:
            return {
                "enabled": True,
                "running": self._running,
                "error": str(e),
            }
def _is_strictly_newer(candidate: Dict[str, Any], current: Dict[str, Any]) -> bool:
    """True if candidate's ``updated_at`` is strictly later than current's.

    Missing or falsy timestamps (absent, None, 0, "") count as 0.
    Timestamps that cannot be ordered against each other (e.g. an ISO
    string vs. a number) count as "not newer". The existing copy is then
    kept rather than aborting the whole merge with TypeError.
    """
    candidate_ts = candidate.get("updated_at") or 0
    current_ts = current.get("updated_at") or 0
    try:
        return candidate_ts > current_ts
    except TypeError:
        return False


def _merge_accounts_bidirectional(
    local_accounts: List[Dict[str, Any]],
    remote_accounts: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Merge local and remote account lists, resolving collisions by ``updated_at``.

    Accounts are matched with ``_get_account_key``. Local accounts are taken
    first (a later local duplicate replaces an earlier one). A remote account
    is added when its key is new, and replaces the stored one only when its
    ``updated_at`` is strictly newer. Ties keep the existing copy. The result
    keeps first-seen key order. Keys starting with ``_`` are dropped from each
    returned dict. The input lists and their dicts are not modified.
    """
    accounts_map: Dict[str, Dict[str, Any]] = {}
    for account in local_accounts:
        accounts_map[_get_account_key(account)] = account
    for account in remote_accounts:
        key = _get_account_key(account)
        existing = accounts_map.get(key)
        if existing is None or _is_strictly_newer(account, existing):
            accounts_map[key] = account
    return [
        {k: v for k, v in account.items() if not k.startswith("_")}
        for account in accounts_map.values()
    ]
_sync_manager: Optional[SyncManager] = None


async def get_sync_manager() -> Optional[SyncManager]:
    """Return the active SyncManager, or None if none has been started."""
    manager = _sync_manager
    return manager
async def start_sync_manager():
    """Create and start the global SyncManager when DATABASE_URL is configured.

    Skips (with a log line) when no URL is set, the URL scheme is
    unsupported, or the remote database fails to initialise. Also skips when
    a manager already exists. Otherwise a repeated call would overwrite it
    and orphan its background sync task and connection pool.
    """
    global _sync_manager
    if _sync_manager is not None:
        logger.info("同步管理器已存在,跳过重复启动")
        return

    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        logger.info("未配置 DATABASE_URL,跳过同步管理器")
        return

    from ..config import DATA_DIR
    local_db = FileSystemDatabase(DATA_DIR)
    await local_db.initialize()

    remote_db = SQLDatabase(database_url)
    if remote_db._db_type == "unknown":
        logger.warning("DATABASE_URL 格式不支持,跳过同步管理器")
        return
    ok = await remote_db.initialize()
    if not ok:
        logger.warning("远程数据库初始化失败,跳过同步管理器")
        return

    _sync_manager = SyncManager(local_db, remote_db)
    await _sync_manager.start()
async def stop_sync_manager():
    """Stop the global SyncManager, if any, and clear the module reference."""
    global _sync_manager
    manager = _sync_manager
    if not manager:
        return
    await manager.stop()
    _sync_manager = None