| from __future__ import annotations |
|
|
| import json |
| from typing import Any |
|
|
| from sqlalchemy import Column, String, Text, create_engine, Integer, text |
| from sqlalchemy.ext.declarative import declarative_base |
| from sqlalchemy.orm import sessionmaker |
|
|
| from services.storage.base import StorageBackend |
|
|
# Shared declarative base: the ORM models in this module subclass it, and
# Base.metadata.create_all() is what creates their tables.
| Base = declarative_base() |
|
|
|
|
class AccountModel(Base):
    """ORM model for one stored account (table ``accounts``)."""

    __tablename__ = "accounts"

    # Surrogate auto-incrementing primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)

    # Key column of the account: required, unique and indexed.
    # NOTE(review): a unique index over String(2048) may exceed the index
    # key-length limit on MySQL with multi-byte charsets -- confirm before
    # deploying against MySQL.
    access_token = Column(String(2048), nullable=False, unique=True, index=True)

    # Account payload as text; DatabaseStorageBackend writes it with
    # json.dumps and reads it back with json.loads.
    data = Column(Text, nullable=False)
|
|
|
|
class AuthKeyModel(Base):
    """ORM model for one stored auth key (table ``auth_keys``)."""

    __tablename__ = "auth_keys"

    # Surrogate auto-incrementing primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)

    # Key column of the auth key: required, unique and indexed.
    # DatabaseStorageBackend fills it from the "id" field of the saved dict.
    key_id = Column(String(255), nullable=False, unique=True, index=True)

    # Auth-key payload as text; DatabaseStorageBackend writes it with
    # json.dumps and reads it back with json.loads.
    data = Column(Text, nullable=False)
|
|
|
|
class DatabaseStorageBackend(StorageBackend):
    """Storage backend that keeps accounts and auth keys in a SQL database.

    Any URL accepted by SQLAlchemy's ``create_engine`` can be used (SQLite,
    PostgreSQL, MySQL, ...). Each record is stored as one row: a key column
    plus the whole record serialized as JSON in the ``data`` column.
    """

    def __init__(self, database_url: str):
        """Create the engine, ensure the tables exist and prepare sessions.

        Args:
            database_url: SQLAlchemy database URL of the target database.
        """
        self.database_url = database_url
        self.engine = create_engine(
            database_url,
            pool_pre_ping=True,  # validate pooled connections before use
            pool_recycle=3600,  # drop pooled connections older than an hour
        )
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    def load_accounts(self) -> list[dict[str, Any]]:
        """Load all account records stored in the database.

        Returns:
            The decoded account dicts; rows that cannot be decoded are skipped.
        """
        return self._load_rows(AccountModel)

    def save_accounts(self, accounts: list[dict[str, Any]]) -> None:
        """Replace the stored accounts with ``accounts``.

        Each account is keyed by its ``"access_token"`` field.
        """
        self._save_rows(AccountModel, accounts, "access_token")

    def load_auth_keys(self) -> list[dict[str, Any]]:
        """Load all auth-key records stored in the database.

        Returns:
            The decoded auth-key dicts; rows that cannot be decoded are
            skipped.
        """
        return self._load_rows(AuthKeyModel)

    def save_auth_keys(self, auth_keys: list[dict[str, Any]]) -> None:
        """Replace the stored auth keys with ``auth_keys``.

        Each auth key is keyed by its ``"id"`` field, which is written to the
        ``key_id`` column.
        """
        self._save_rows(AuthKeyModel, auth_keys, "id", "key_id")

    def _load_rows(
        self,
        model: type[AccountModel] | type[AuthKeyModel],
    ) -> list[dict[str, Any]]:
        """Read every row of ``model`` and decode its JSON ``data`` column.

        Rows whose ``data`` is not valid JSON, or whose JSON value is not an
        object, are skipped so one corrupt row cannot break the whole load.

        Args:
            model: ORM model class whose table is read.

        Returns:
            The decoded dicts, in the order the query returned the rows.
        """
        session = self.Session()
        try:
            items: list[dict[str, Any]] = []
            for row in session.query(model).all():
                try:
                    payload = json.loads(row.data)
                except json.JSONDecodeError:
                    continue
                if isinstance(payload, dict):
                    items.append(payload)
            return items
        finally:
            session.close()

    def _save_rows(
        self,
        model: type[AccountModel] | type[AuthKeyModel],
        items: list[dict[str, Any]],
        source_key: str,
        target_key: str | None = None,
    ) -> None:
        """Replace all rows of ``model`` with ``items`` in one transaction.

        Existing rows are deleted and one row is inserted per usable item.
        Items that are not dicts, or whose ``source_key`` value is missing or
        blank after stripping, are silently skipped. On any error the
        transaction is rolled back and the exception is re-raised, leaving
        the previous rows in place; this includes the database rejecting two
        items that share the same key value.

        Args:
            model: ORM model class whose table is rewritten.
            items: Records to store; each one is serialized as JSON.
            source_key: Field of each item that provides the key value.
            target_key: Column that receives the key value; defaults to
                ``source_key`` when omitted.

        Raises:
            Exception: Whatever the session raised while deleting, inserting
                or committing.
        """
        column_name = target_key or source_key
        session = self.Session()
        try:
            session.query(model).delete()
            for item in items:
                if not isinstance(item, dict):
                    continue
                key_value = str(item.get(source_key) or "").strip()
                if not key_value:
                    continue
                session.add(
                    model(
                        **{column_name: key_value},
                        data=json.dumps(item, ensure_ascii=False),
                    )
                )
            session.commit()
        except Exception:
            session.rollback()
            # Bare raise keeps the original exception and traceback intact.
            raise
        finally:
            session.close()

    def health_check(self) -> dict[str, Any]:
        """Probe the database and report its status.

        Runs ``SELECT 1`` and counts the rows of both tables. Failures are
        reported in the returned dict instead of being raised.

        Returns:
            On success a dict with ``"status": "healthy"``, the masked
            database URL and both row counts; on failure a dict with
            ``"status": "unhealthy"`` and the error text.
        """
        try:
            session = self.Session()
            try:
                session.execute(text("SELECT 1"))
                count = session.query(AccountModel).count()
                auth_key_count = session.query(AuthKeyModel).count()
                return {
                    "status": "healthy",
                    "backend": "database",
                    "database_url": self._mask_password(self.database_url),
                    "account_count": count,
                    "auth_key_count": auth_key_count,
                }
            finally:
                session.close()
        except Exception as e:
            return {
                "status": "unhealthy",
                "backend": "database",
                "error": str(e),
            }

    def get_backend_info(self) -> dict[str, Any]:
        """Describe this backend.

        Returns:
            A dict with the backend type, the detected database type, a
            description string and the masked database URL.
        """
        db_type = self._detect_db_type(self.database_url)
        return {
            "type": "database",
            "db_type": db_type,
            "description": f"数据库存储 ({db_type})",
            "database_url": self._mask_password(self.database_url),
        }

    @staticmethod
    def _detect_db_type(url: str) -> str:
        """Derive the database family from the scheme of a database URL.

        Only the part before ``://`` is inspected, so a host, database name
        or password that happens to contain "sqlite" or "mysql" cannot cause
        a misdetection.

        Returns:
            ``"sqlite"``, ``"postgresql"``, ``"mysql"`` or ``"unknown"``.
        """
        scheme = url.split("://", 1)[0].lower()
        known_prefixes = (
            ("sqlite", "sqlite"),
            ("postgres", "postgresql"),
            ("mysql", "mysql"),
            # Reported as mysql so "mariadb+pymysql://" URLs keep the result
            # they had under the old substring match.
            ("mariadb", "mysql"),
        )
        for prefix, db_type in known_prefixes:
            if scheme.startswith(prefix):
                return db_type
        return "unknown"

    @staticmethod
    def _mask_password(url: str) -> str:
        """Return ``url`` with the password replaced by ``****``.

        The password is taken to be the text between the first ``:`` and the
        first ``@`` that follow ``://``. URLs without ``://``, without ``@``
        or without a ``:`` in the credentials are returned unchanged.
        """
        protocol, scheme_sep, rest = url.partition("://")
        if not scheme_sep:
            return url
        credentials, at_sign, host = rest.partition("@")
        if not at_sign:
            return url
        username, colon, _password = credentials.partition(":")
        if not colon:
            return url
        return f"{protocol}://{username}:****@{host}"
|
|