| """存储抽象层 - 支持文件、MySQL和Redis存储""" |
|
|
| import os |
| import orjson |
| import toml |
| import asyncio |
| import warnings |
| import aiofiles |
| from pathlib import Path |
| from typing import Dict, Any, Optional, Literal |
| from abc import ABC, abstractmethod |
| from urllib.parse import urlparse, unquote |
|
|
| from app.core.logger import logger |
|
|
|
|
# Names of the supported storage backends; these match the STORAGE_MODE
# values that StorageManager.init maps to concrete storage classes.
StorageMode = Literal[
    "file",
    "mysql",
    "redis",
]
|
|
|
|
class BaseStorage(ABC):
    """Abstract interface every storage backend implements.

    All operations are coroutines. Token data and configuration data are
    handled as plain dictionaries; each backend decides where they live.
    """

    @abstractmethod
    async def init_db(self) -> None:
        """Prepare the backend (create files/tables/connections) before use."""
        ...

    @abstractmethod
    async def load_tokens(self) -> Dict[str, Any]:
        """Return the stored token data."""
        ...

    @abstractmethod
    async def save_tokens(self, data: Dict[str, Any]) -> None:
        """Persist ``data`` as the token data."""
        ...

    @abstractmethod
    async def load_config(self) -> Dict[str, Any]:
        """Return the stored configuration data."""
        ...

    @abstractmethod
    async def save_config(self, data: Dict[str, Any]) -> None:
        """Persist ``data`` as the configuration data."""
        ...
|
|
|
|
class FileStorage(BaseStorage):
    """Local file storage: tokens in ``token.json``, settings in ``setting.toml``.

    Loads and saves of each file go through a dedicated ``asyncio.Lock`` (one
    for tokens, one for config), so within one instance they never interleave.
    Writes go to a temporary sibling file that is then moved over the target
    with ``os.replace``, so an interrupted write leaves the previous content
    intact instead of a truncated file.
    """

    def __init__(self, data_dir: Path):
        """Args:
            data_dir: directory holding ``token.json`` and ``setting.toml``.
        """
        self.data_dir = data_dir
        self.token_file = data_dir / "token.json"
        self.config_file = data_dir / "setting.toml"
        self._token_lock = asyncio.Lock()
        self._config_lock = asyncio.Lock()

    async def init_db(self) -> None:
        """Create the data directory and seed any missing file with defaults."""
        self.data_dir.mkdir(parents=True, exist_ok=True)

        if not self.token_file.exists():
            await self._write(self.token_file, orjson.dumps({"sso": {}, "ssoSuper": {}}, option=orjson.OPT_INDENT_2).decode())
            logger.info("[Storage] 创建token文件")

        if not self.config_file.exists():
            default = {
                "global": {"api_keys": [], "admin_username": "admin", "admin_password": "admin"},
                "grok": {"proxy_url": "", "cf_clearance": "", "x_statsig_id": ""}
            }
            await self._write(self.config_file, toml.dumps(default))
            logger.info("[Storage] 创建配置文件")

    async def _read(self, path: Path) -> str:
        """Return the whole UTF-8 text content of ``path``."""
        async with aiofiles.open(path, "r", encoding="utf-8") as f:
            return await f.read()

    async def _write(self, path: Path, content: str) -> None:
        """Replace the content of ``path`` with ``content`` (UTF-8).

        The text is first written to a uniquely named temporary file in the
        same directory (so the rename stays on one filesystem) and then moved
        over ``path`` with ``os.replace``. A failure while writing the temp
        file propagates and leaves ``path`` untouched. Only if the OS refuses
        the rename itself does this fall back to writing ``path`` in place.
        """
        # pid + random suffix: other processes sharing the data dir, or the
        # unlocked writes in init_db, never share a temp file with us.
        tmp = path.with_name(f"{path.name}.{os.getpid()}-{os.urandom(4).hex()}.tmp")
        try:
            async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
                await f.write(content)
            try:
                os.replace(tmp, path)
                return
            except OSError:
                # Rename refused (e.g. EBUSY when ``path`` is itself a
                # bind-mounted file); fall through to the in-place write.
                pass
        finally:
            if tmp.exists():
                try:
                    tmp.unlink()
                except OSError:
                    pass  # best-effort cleanup; a stray temp file is harmless
        async with aiofiles.open(path, "w", encoding="utf-8") as f:
            await f.write(content)

    async def _load_json(self, path: Path, default: Dict, lock: asyncio.Lock) -> Dict[str, Any]:
        """Return parsed JSON from ``path``.

        Returns ``default`` when the file is missing, or when reading/parsing
        fails (the error is logged).
        """
        try:
            async with lock:
                if not path.exists():
                    return default
                return orjson.loads(await self._read(path))
        except Exception as e:
            logger.error(f"[Storage] 加载{path.name}失败: {e}")
            return default

    async def _save_json(self, path: Path, data: Dict, lock: asyncio.Lock) -> None:
        """Write ``data`` to ``path`` as 2-space-indented JSON; errors are logged and re-raised."""
        try:
            async with lock:
                await self._write(path, orjson.dumps(data, option=orjson.OPT_INDENT_2).decode())
        except Exception as e:
            logger.error(f"[Storage] 保存{path.name}失败: {e}")
            raise

    async def _load_toml(self, path: Path, default: Dict, lock: asyncio.Lock) -> Dict[str, Any]:
        """Return parsed TOML from ``path``.

        Returns ``default`` when the file is missing, or when reading/parsing
        fails (the error is logged).
        """
        try:
            async with lock:
                if not path.exists():
                    return default
                return toml.loads(await self._read(path))
        except Exception as e:
            logger.error(f"[Storage] 加载{path.name}失败: {e}")
            return default

    async def _save_toml(self, path: Path, data: Dict, lock: asyncio.Lock) -> None:
        """Write ``data`` to ``path`` as TOML; errors are logged and re-raised."""
        try:
            async with lock:
                await self._write(path, toml.dumps(data))
        except Exception as e:
            logger.error(f"[Storage] 保存{path.name}失败: {e}")
            raise

    async def load_tokens(self) -> Dict[str, Any]:
        """Load token data; defaults to empty ``sso`` / ``ssoSuper`` maps."""
        return await self._load_json(self.token_file, {"sso": {}, "ssoSuper": {}}, self._token_lock)

    async def save_tokens(self, data: Dict[str, Any]) -> None:
        """Persist token data to ``token.json``."""
        await self._save_json(self.token_file, data, self._token_lock)

    async def load_config(self) -> Dict[str, Any]:
        """Load configuration; defaults to empty ``global`` / ``grok`` sections."""
        return await self._load_toml(self.config_file, {"global": {}, "grok": {}}, self._config_lock)

    async def save_config(self, data: Dict[str, Any]) -> None:
        """Persist configuration to ``setting.toml``."""
        await self._save_toml(self.config_file, data, self._config_lock)
|
|
|
|
class MysqlStorage(BaseStorage):
    """MySQL-backed storage mirrored into a local FileStorage.

    Loads are served from the local files. Every save writes the file first
    and then the newest row of the matching table (``grok_tokens`` /
    ``grok_settings``). At startup ``_sync_data`` reconciles the two sides.
    """

    def __init__(self, database_url: str, data_dir: Path):
        """Args:
            database_url: ``mysql://user:pass@host:port/db`` style URL.
            data_dir: directory for the local file mirror.
        """
        self.database_url = database_url
        self.data_dir = data_dir
        self._pool = None
        self._file = FileStorage(data_dir)

    async def init_db(self) -> None:
        """Prepare MySQL and the file mirror, then reconcile them.

        Creates the database if missing and opens a pool (maxsize 10,
        autocommit). Then creates the tables, initialises the local files and
        runs ``_sync_data``.

        Raises:
            Exception: when aiomysql cannot be imported (ImportError chained);
                any other failure is logged and re-raised.
        """
        try:
            import aiomysql
            parsed = self._parse_url(self.database_url)
            logger.info(f"[Storage] MySQL: {parsed['user']}@{parsed['host']}:{parsed['port']}/{parsed['db']}")

            await self._create_db(parsed)
            self._pool = await aiomysql.create_pool(
                host=parsed['host'], port=parsed['port'], user=parsed['user'],
                password=parsed['password'], db=parsed['db'], charset="utf8mb4",
                autocommit=True, maxsize=10
            )
            await self._create_tables()
            await self._file.init_db()
            await self._sync_data()
        except ImportError as e:
            raise Exception("aiomysql未安装") from e
        except Exception as e:
            logger.error(f"[Storage] MySQL初始化失败: {e}")
            raise

    def _parse_url(self, url: str) -> Dict[str, Any]:
        """Split a database URL into aiomysql connection keyword values.

        Percent-encoded user and password are decoded. Missing parts default
        to host ``localhost``, port 3306 and database ``grok2api``. The
        database default also applies when the path is empty or just ``/``.
        """
        p = urlparse(url)
        return {
            'user': unquote(p.username) if p.username else "",
            'password': unquote(p.password) if p.password else "",
            'host': p.hostname or "localhost",
            'port': p.port or 3306,
            'db': p.path.lstrip('/') or "grok2api",
        }

    async def _create_db(self, parsed: Dict) -> None:
        """Create the target database if needed, via a throwaway 1-connection pool."""
        import aiomysql
        pool = await aiomysql.create_pool(
            host=parsed['host'], port=parsed['port'], user=parsed['user'],
            password=parsed['password'], charset="utf8mb4", autocommit=True, maxsize=1
        )

        # Identifiers cannot be bound as query parameters. Doubling backticks
        # keeps the configured name a single quoted identifier.
        db_ident = parsed['db'].replace("`", "``")
        try:
            async with pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    with warnings.catch_warnings():
                        warnings.filterwarnings('ignore', message='.*database exists')
                        await cursor.execute(
                            f"CREATE DATABASE IF NOT EXISTS `{db_ident}` "
                            f"CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
                        )
                    logger.info(f"[Storage] 数据库 '{parsed['db']}' 就绪")
        finally:
            pool.close()
            await pool.wait_closed()

    async def _create_tables(self) -> None:
        """Create the ``grok_tokens`` and ``grok_settings`` tables if they do not exist."""
        tables = {
            "grok_tokens": """
                CREATE TABLE IF NOT EXISTS grok_tokens (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    data JSON NOT NULL,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """,
            "grok_settings": """
                CREATE TABLE IF NOT EXISTS grok_settings (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    data JSON NOT NULL,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """
        }

        async with self._pool.acquire() as conn:
            async with conn.cursor() as cursor:
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', message='.*already exists')
                    for sql in tables.values():
                        await cursor.execute(sql)
                logger.info("[Storage] MySQL表就绪")

    async def _sync_data(self) -> None:
        """Reconcile DB and local files at startup, table by table.

        If the DB holds data it overwrites the local file. Otherwise non-empty
        local data is pushed into the DB. A DB read error aborts the sync of
        that table only (logged as a warning). It is never mistaken for
        "DB empty", which would overwrite the DB with local/default data.
        """
        targets = (
            ("grok_tokens", "sso", self._file.load_tokens, self._file.save_tokens),
            ("grok_settings", "global", self._file.load_config, self._file.save_config),
        )
        for table, key, load_local, save_local in targets:
            try:
                data = await self._load_db(table, strict=True)
                if data:
                    await save_local(data)
                    logger.info(f"[Storage] {table.split('_')[1]}数据已从DB同步")
                    continue

                file_data = await load_local()
                if file_data.get(key) or (table == "grok_tokens" and file_data.get("ssoSuper")):
                    await self._save_db(table, file_data)
                    logger.info(f"[Storage] {table.split('_')[1]}数据已初始化到DB")
            except Exception as e:
                logger.warning(f"[Storage] 同步失败: {e}")

    async def _load_db(self, table: str, *, strict: bool = False) -> Optional[Dict]:
        """Return the decoded JSON ``data`` of the newest row in ``table``.

        Returns ``None`` when the table is empty. On a query/decode error the
        error is logged. With ``strict`` false the method then returns
        ``None``; with ``strict`` true the exception is re-raised so callers
        can tell "empty" from "failed".
        ``table`` is interpolated into SQL; callers in this class only pass
        fixed table names.
        """
        try:
            async with self._pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(f"SELECT data FROM {table} ORDER BY id DESC LIMIT 1")
                    result = await cursor.fetchone()
                    return orjson.loads(result[0]) if result else None
        except Exception as e:
            logger.error(f"[Storage] 加载{table}失败: {e}")
            if strict:
                raise
            return None

    async def _save_db(self, table: str, data: Dict) -> None:
        """Store ``data`` as JSON in ``table``.

        The newest row is updated in place if one exists; otherwise a row is
        inserted. Errors are logged and re-raised.
        """
        try:
            async with self._pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    json_data = orjson.dumps(data).decode()
                    await cursor.execute(f"SELECT id FROM {table} ORDER BY id DESC LIMIT 1")
                    result = await cursor.fetchone()

                    if result:
                        await cursor.execute(f"UPDATE {table} SET data = %s WHERE id = %s", (json_data, result[0]))
                    else:
                        await cursor.execute(f"INSERT INTO {table} (data) VALUES (%s)", (json_data,))
        except Exception as e:
            logger.error(f"[Storage] 保存{table}失败: {e}")
            raise

    async def load_tokens(self) -> Dict[str, Any]:
        """Load token data from the local file mirror."""
        return await self._file.load_tokens()

    async def save_tokens(self, data: Dict[str, Any]) -> None:
        """Save token data to the local file, then to ``grok_tokens``."""
        await self._file.save_tokens(data)
        await self._save_db("grok_tokens", data)

    async def load_config(self) -> Dict[str, Any]:
        """Load configuration from the local file mirror."""
        return await self._file.load_config()

    async def save_config(self, data: Dict[str, Any]) -> None:
        """Save configuration to the local file, then to ``grok_settings``."""
        await self._file.save_config(data)
        await self._save_db("grok_settings", data)

    async def close(self) -> None:
        """Close the connection pool if it was opened."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            logger.info("[Storage] MySQL已关闭")
|
|
|
|
class RedisStorage(BaseStorage):
    """Redis-backed storage mirrored into a local FileStorage.

    Loads are served from the local files. Every save writes the file first
    and then the JSON string under ``grok:tokens`` / ``grok:settings``.
    """

    def __init__(self, redis_url: str, data_dir: Path):
        """Args:
            redis_url: URL accepted by ``redis.asyncio.Redis.from_url``.
            data_dir: directory for the local file mirror.
        """
        self.redis_url = redis_url
        self.data_dir = data_dir
        self._redis = None
        self._file = FileStorage(data_dir)

    def _describe_url(self) -> str:
        """Return a password-free label of the Redis target for logging.

        TCP URLs render as ``host:port/db``. ``unix://`` socket URLs render
        their socket path instead of being parsed as a numeric database
        index, which would raise ValueError.
        """
        parsed = urlparse(self.redis_url)
        if parsed.scheme == "unix":
            return f"unix:{parsed.path}"
        db = parsed.path.lstrip('/') or "0"
        return f"{parsed.hostname}:{parsed.port or 6379}/{db}"

    async def init_db(self) -> None:
        """Connect to Redis, initialise the file mirror and reconcile both.

        Raises:
            Exception: when the redis package cannot be imported (ImportError
                chained); any other failure is logged and re-raised.
        """
        try:
            import redis.asyncio as aioredis
            logger.info(f"[Storage] Redis: {self._describe_url()}")

            self._redis = aioredis.Redis.from_url(
                self.redis_url, encoding="utf-8", decode_responses=True
            )

            await self._redis.ping()
            logger.info("[Storage] Redis连接成功")

            await self._file.init_db()
            await self._sync_data()
        except ImportError as e:
            raise Exception("redis未安装") from e
        except Exception as e:
            logger.error(f"[Storage] Redis初始化失败: {e}")
            raise

    async def _sync_data(self) -> None:
        """Reconcile Redis and local files at startup, key by key.

        If Redis holds a value it overwrites the local file. Otherwise
        non-empty local data is written to Redis. A failure on one key is
        logged as a warning and does not skip the other key.
        """
        targets = (
            ("grok:tokens", self._file.load_tokens, self._file.save_tokens, "sso"),
            ("grok:settings", self._file.load_config, self._file.save_config, "global"),
        )
        for key, load_local, save_local, key_name in targets:
            try:
                data = await self._redis.get(key)
                if data:
                    await save_local(orjson.loads(data))
                    logger.info(f"[Storage] {key.split(':')[1]}数据已从Redis同步")
                    continue

                file_data = await load_local()
                if file_data.get(key_name) or (key == "grok:tokens" and file_data.get("ssoSuper")):
                    await self._redis.set(key, orjson.dumps(file_data).decode())
                    logger.info(f"[Storage] {key.split(':')[1]}数据已初始化到Redis")
            except Exception as e:
                logger.warning(f"[Storage] 同步失败: {e}")

    async def _save_redis(self, key: str, data: Dict) -> None:
        """Store ``data`` as a JSON string under ``key``; errors are logged and re-raised."""
        try:
            await self._redis.set(key, orjson.dumps(data).decode())
        except Exception as e:
            logger.error(f"[Storage] 保存Redis失败: {e}")
            raise

    async def load_tokens(self) -> Dict[str, Any]:
        """Load token data from the local file mirror."""
        return await self._file.load_tokens()

    async def save_tokens(self, data: Dict[str, Any]) -> None:
        """Save token data to the local file, then to ``grok:tokens``."""
        await self._file.save_tokens(data)
        await self._save_redis("grok:tokens", data)

    async def load_config(self) -> Dict[str, Any]:
        """Load configuration from the local file mirror."""
        return await self._file.load_config()

    async def save_config(self, data: Dict[str, Any]) -> None:
        """Save configuration to the local file, then to ``grok:settings``."""
        await self._file.save_config(data)
        await self._save_redis("grok:settings", data)

    async def close(self) -> None:
        """Close the Redis client if it was created."""
        if self._redis:
            # redis-py >= 5.0.1 provides aclose() and deprecates close();
            # fall back to close() on older client versions.
            closer = getattr(self._redis, "aclose", None) or self._redis.close
            await closer()
            logger.info("[Storage] Redis已关闭")
|
|
|
|
class StorageManager:
    """Process-wide singleton that selects, initialises and owns the storage backend."""

    _instance: Optional['StorageManager'] = None
    _storage: Optional["BaseStorage"] = None
    _initialized: bool = False

    # STORAGE_MODE values this manager knows how to construct.
    _SUPPORTED_MODES = ("file", "mysql", "redis")

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def _normalize_mode(cls, raw: Optional[str]) -> str:
        """Return the canonical backend name for a raw STORAGE_MODE value.

        Whitespace and case are ignored. Empty, missing or unknown values map
        to ``"file"``.
        """
        mode = (raw or "").strip().lower()
        return mode if mode in cls._SUPPORTED_MODES else "file"

    async def init(self) -> None:
        """Create and initialise the backend chosen by STORAGE_MODE (idempotent).

        Reads ``STORAGE_MODE`` (default ``file``) and ``DATABASE_URL``. An
        unknown or empty mode falls back to file storage with a warning.
        Previously such a value reached ``FileStorage(url, data_dir)`` and
        raised TypeError.

        Raises:
            ValueError: mysql/redis mode is selected without DATABASE_URL.
        """
        if self._initialized:
            return

        raw_mode = os.getenv("STORAGE_MODE", "file")
        mode = self._normalize_mode(raw_mode)
        if raw_mode.strip().lower() not in self._SUPPORTED_MODES:
            logger.warning(f"[Storage] 未知存储模式 '{raw_mode}', 回退到file模式")

        url = os.getenv("DATABASE_URL", "")
        data_dir = Path(__file__).parents[2] / "data"

        if mode in ("mysql", "redis") and not url:
            raise ValueError(f"{mode.upper()}模式需要DATABASE_URL")

        # File storage takes only the directory; the DB-backed ones also need the URL.
        if mode == "mysql":
            self._storage = MysqlStorage(url, data_dir)
        elif mode == "redis":
            self._storage = RedisStorage(url, data_dir)
        else:
            self._storage = FileStorage(data_dir)

        await self._storage.init_db()
        self._initialized = True
        logger.info(f"[Storage] 使用{mode}模式")

    def get_storage(self) -> "BaseStorage":
        """Return the initialised backend.

        Raises:
            RuntimeError: ``init()`` has not completed successfully.
        """
        if not self._initialized or not self._storage:
            raise RuntimeError("StorageManager未初始化")
        return self._storage

    async def close(self) -> None:
        """Close the backend if it exposes a ``close`` coroutine (FileStorage does not)."""
        if self._storage and hasattr(self._storage, 'close'):
            await self._storage.close()
|
|
|
|
| |
# Module-level singleton shared by every importer of this module;
# ``await storage_manager.init()`` must complete before ``get_storage()``.
storage_manager: StorageManager = StorageManager()
|
|