# database_sql.py # ========================================== # 🛡️ 稳定性优化:SQL 数据库连接模块 # ========================================== import os import time import logging from functools import wraps from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.exc import OperationalError, InterfaceError, DBAPIError from models_sql import Base logger = logging.getLogger("ComfyUI-Ranking.Database") # 核心:优先读取环境变量中的 PostgreSQL 数据库连接 SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:////tmp/comfy_financial.db") # 🚀 P2优化:连接池配置支持环境变量 POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", "5")) POOL_OVERFLOW = int(os.environ.get("DB_POOL_OVERFLOW", "10")) # 🚀 P1性能优化:根据数据库类型配置连接池 if "sqlite" in SQLALCHEMY_DATABASE_URL: # SQLite:使用空连接池,单线程模式 connect_args = {"check_same_thread": False} engine = create_engine( SQLALCHEMY_DATABASE_URL, connect_args=connect_args, poolclass=NullPool # SQLite 不需要连接池 ) logger.info(f"数据库引擎初始化完成 (SQLite)") else: # PostgreSQL/MySQL:配置连接池参数 connect_args = {} if "postgresql" in SQLALCHEMY_DATABASE_URL or "postgres" in SQLALCHEMY_DATABASE_URL: sslmode = os.environ.get("DB_SSLMODE", "require") connect_args["sslmode"] = sslmode connect_args["connect_timeout"] = int(os.environ.get("DB_CONNECT_TIMEOUT", "10")) logger.info(f"PostgreSQL SSL 配置: sslmode={sslmode}, connect_timeout={connect_args['connect_timeout']}") engine = create_engine( SQLALCHEMY_DATABASE_URL, poolclass=QueuePool, pool_size=POOL_SIZE, # 核心连接数(支持环境变量配置) max_overflow=POOL_OVERFLOW, # 超出 pool_size 后可创建的最大连接数(支持环境变量配置) pool_timeout=30, # 获取连接超时(秒) pool_recycle=1800, # 连接回收时间(30分钟) pool_pre_ping=True, # 使用前检测连接有效性 connect_args=connect_args ) logger.info(f"数据库引擎初始化完成 (PostgreSQL/MySQL) - 连接池配置: pool_size={POOL_SIZE}, max_overflow={POOL_OVERFLOW}") SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # ========================================== # 🛡️ 稳定性优化:数据库操作重试装饰器 # ========================================== def db_retry(max_retries: int = 3, delay: float = 0.5): """ 数据库操作重试装饰器 用法: @db_retry(max_retries=3) def my_db_function(db): ... """ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): last_error = None for attempt in range(max_retries): try: return func(*args, **kwargs) except (OperationalError, InterfaceError, DBAPIError) as e: last_error = e if attempt < max_retries - 1: wait_time = delay * (2 ** attempt) # 指数退避 logger.warning(f"DB 操作失败 (第{attempt+1}次):{e}, {wait_time}秒后重试") time.sleep(wait_time) else: logger.error(f"DB 操作失败 (重试{max_retries}次后放弃): {e}") raise last_error return wrapper return decorator def init_sql_db(): """初始化数据库,包含重试机制""" for attempt in range(3): try: Base.metadata.create_all(bind=engine) # 🔄 P7后悔模式:自动迁移新增字段 _auto_migrate_p7_fields() logger.info("数据库初始化成功") return except Exception as e: if attempt < 2: logger.warning(f"数据库初始化失败 (第{attempt+1}次): {e}") time.sleep(2 ** attempt) else: logger.error(f"数据库初始化失败 (已重试3次): {e}") raise def _auto_migrate_p7_fields(): """ 🔄 P7后悔模式:自动迁移新增字段 检查并添加 ownerships 表和 transactions 表的新字段 """ from sqlalchemy import inspect try: inspector = inspect(engine) # ========== 1. 迁移 ownerships 表 ========== if 'ownerships' in inspector.get_table_names(): columns = [col['name'] for col in inspector.get_columns('ownerships')] with engine.connect() as conn: if 'price_paid' not in columns: if 'sqlite' in SQLALCHEMY_DATABASE_URL: conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0")) else: conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0")) logger.info("迁移完成: 添加 ownerships.price_paid 字段") if 'is_refunded' not in columns: if 'sqlite' in SQLALCHEMY_DATABASE_URL: conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT 0")) else: conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT FALSE")) logger.info("迁移完成: 添加 ownerships.is_refunded 字段") if 'refunded_at' not in columns: conn.execute(text("ALTER TABLE ownerships ADD COLUMN refunded_at TIMESTAMP")) logger.info("迁移完成: 添加 ownerships.refunded_at 字段") conn.commit() # ========== 2. 迁移 wallets 表(新增 task_balance 字段) ========== if 'wallets' in inspector.get_table_names(): columns = [col['name'] for col in inspector.get_columns('wallets')] with engine.connect() as conn: if 'task_balance' not in columns: if 'sqlite' in SQLALCHEMY_DATABASE_URL: conn.execute(text("ALTER TABLE wallets ADD COLUMN task_balance INTEGER DEFAULT 0")) else: conn.execute(text("ALTER TABLE wallets ADD COLUMN task_balance INTEGER DEFAULT 0")) logger.info("[DB Migration] 添加列 wallets.task_balance") conn.commit() # ========== 3. 迁移 transactions 表(提现相关新字段) ========== if 'transactions' in inspector.get_table_names(): columns = [col['name'] for col in inspector.get_columns('transactions')] with engine.connect() as conn: # 定义 transactions 表的新列 new_columns = { 'alipay_account': 'VARCHAR', 'real_name': 'VARCHAR', 'withdraw_status': 'VARCHAR', 'payment_order_id': 'VARCHAR', 'net_amount': 'INTEGER', 'description': 'VARCHAR', 'item_title': 'VARCHAR', 'item_type': 'VARCHAR', 'related_user_name': 'VARCHAR', } for col_name, col_type in new_columns.items(): if col_name not in columns: # VARCHAR 类型添加 DEFAULT NULL 避免 PostgreSQL NOT NULL 冲突 if col_type == 'VARCHAR': conn.execute(text(f"ALTER TABLE transactions ADD COLUMN {col_name} {col_type} DEFAULT NULL")) else: conn.execute(text(f"ALTER TABLE transactions ADD COLUMN {col_name} {col_type}")) logger.info(f"[DB Migration] 添加列 transactions.{col_name}") conn.commit() # ========== 4. 回填旧提现记录数据 ========== # 将旧 WITHDRAW 记录的 withdraw_status 从 NULL 更新为 "completed" result = conn.execute(text( "UPDATE transactions SET withdraw_status = 'completed' " "WHERE tx_type = 'WITHDRAW' AND withdraw_status IS NULL" )) if result.rowcount > 0: logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 withdraw_status 为 'completed'") # 将旧记录的 net_amount 从 NULL 更新为 ABS(amount) result = conn.execute(text( "UPDATE transactions SET net_amount = ABS(amount) " "WHERE tx_type = 'WITHDRAW' AND net_amount IS NULL" )) if result.rowcount > 0: logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 net_amount 为 ABS(amount)") conn.commit() except Exception as e: logger.warning(f"字段迁移跳过 (可能已存在): {e}") def get_db(): """获取数据库会话,带连接有效性检测""" db = SessionLocal() try: yield db except (OperationalError, InterfaceError) as e: logger.error(f"数据库连接错误: {e}") db.rollback() raise finally: db.close() def check_db_connection() -> bool: """ 检查数据库连接是否正常 用于健康检查接口 """ try: db = SessionLocal() db.execute(text("SELECT 1")) db.close() return True except Exception as e: logger.error(f"数据库连接检查失败: {e}") return False