Spaces:
Running
Running
| # database_sql.py | |
| # ========================================== | |
| # 🛡️ 稳定性优化:SQL 数据库连接模块 | |
| # ========================================== | |
| import os | |
| import time | |
| import logging | |
| from functools import wraps | |
| from sqlalchemy import create_engine, text | |
| from sqlalchemy.orm import sessionmaker | |
| from sqlalchemy.pool import QueuePool, NullPool | |
| from sqlalchemy.exc import OperationalError, InterfaceError, DBAPIError | |
| from models_sql import Base | |
| logger = logging.getLogger("ComfyUI-Ranking.Database") | |
| # 核心:优先读取环境变量中的 PostgreSQL 数据库连接 | |
| SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:////tmp/comfy_financial.db") | |
| # 🚀 P2优化:连接池配置支持环境变量 | |
| POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", "5")) | |
| POOL_OVERFLOW = int(os.environ.get("DB_POOL_OVERFLOW", "10")) | |
| # 🚀 P1性能优化:根据数据库类型配置连接池 | |
| if "sqlite" in SQLALCHEMY_DATABASE_URL: | |
| # SQLite:使用空连接池,单线程模式 | |
| connect_args = {"check_same_thread": False} | |
| engine = create_engine( | |
| SQLALCHEMY_DATABASE_URL, | |
| connect_args=connect_args, | |
| poolclass=NullPool # SQLite 不需要连接池 | |
| ) | |
| logger.info(f"数据库引擎初始化完成 (SQLite)") | |
| else: | |
| # PostgreSQL/MySQL:配置连接池参数 | |
| connect_args = {} | |
| if "postgresql" in SQLALCHEMY_DATABASE_URL or "postgres" in SQLALCHEMY_DATABASE_URL: | |
| sslmode = os.environ.get("DB_SSLMODE", "require") | |
| connect_args["sslmode"] = sslmode | |
| connect_args["connect_timeout"] = int(os.environ.get("DB_CONNECT_TIMEOUT", "10")) | |
| logger.info(f"PostgreSQL SSL 配置: sslmode={sslmode}, connect_timeout={connect_args['connect_timeout']}") | |
| engine = create_engine( | |
| SQLALCHEMY_DATABASE_URL, | |
| poolclass=QueuePool, | |
| pool_size=POOL_SIZE, # 核心连接数(支持环境变量配置) | |
| max_overflow=POOL_OVERFLOW, # 超出 pool_size 后可创建的最大连接数(支持环境变量配置) | |
| pool_timeout=30, # 获取连接超时(秒) | |
| pool_recycle=1800, # 连接回收时间(30分钟) | |
| pool_pre_ping=True, # 使用前检测连接有效性 | |
| connect_args=connect_args | |
| ) | |
| logger.info(f"数据库引擎初始化完成 (PostgreSQL/MySQL) - 连接池配置: pool_size={POOL_SIZE}, max_overflow={POOL_OVERFLOW}") | |
| SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | |
| # ========================================== | |
| # 🛡️ 稳定性优化:数据库操作重试装饰器 | |
| # ========================================== | |
| def db_retry(max_retries: int = 3, delay: float = 0.5): | |
| """ | |
| 数据库操作重试装饰器 | |
| 用法: | |
| @db_retry(max_retries=3) | |
| def my_db_function(db): | |
| ... | |
| """ | |
| def decorator(func): | |
| def wrapper(*args, **kwargs): | |
| last_error = None | |
| for attempt in range(max_retries): | |
| try: | |
| return func(*args, **kwargs) | |
| except (OperationalError, InterfaceError, DBAPIError) as e: | |
| last_error = e | |
| if attempt < max_retries - 1: | |
| wait_time = delay * (2 ** attempt) # 指数退避 | |
| logger.warning(f"DB 操作失败 (第{attempt+1}次):{e}, {wait_time}秒后重试") | |
| time.sleep(wait_time) | |
| else: | |
| logger.error(f"DB 操作失败 (重试{max_retries}次后放弃): {e}") | |
| raise last_error | |
| return wrapper | |
| return decorator | |
| def init_sql_db(): | |
| """初始化数据库,包含重试机制""" | |
| for attempt in range(3): | |
| try: | |
| Base.metadata.create_all(bind=engine) | |
| # 🔄 P7后悔模式:自动迁移新增字段 | |
| _auto_migrate_p7_fields() | |
| logger.info("数据库初始化成功") | |
| return | |
| except Exception as e: | |
| if attempt < 2: | |
| logger.warning(f"数据库初始化失败 (第{attempt+1}次): {e}") | |
| time.sleep(2 ** attempt) | |
| else: | |
| logger.error(f"数据库初始化失败 (已重试3次): {e}") | |
| raise | |
| def _auto_migrate_p7_fields(): | |
| """ | |
| 🔄 P7后悔模式:自动迁移新增字段 | |
| 检查并添加 ownerships 表和 transactions 表的新字段 | |
| """ | |
| from sqlalchemy import inspect | |
| try: | |
| inspector = inspect(engine) | |
| # ========== 1. 迁移 ownerships 表 ========== | |
| if 'ownerships' in inspector.get_table_names(): | |
| columns = [col['name'] for col in inspector.get_columns('ownerships')] | |
| with engine.connect() as conn: | |
| if 'price_paid' not in columns: | |
| if 'sqlite' in SQLALCHEMY_DATABASE_URL: | |
| conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0")) | |
| else: | |
| conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0")) | |
| logger.info("迁移完成: 添加 ownerships.price_paid 字段") | |
| if 'is_refunded' not in columns: | |
| if 'sqlite' in SQLALCHEMY_DATABASE_URL: | |
| conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT 0")) | |
| else: | |
| conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT FALSE")) | |
| logger.info("迁移完成: 添加 ownerships.is_refunded 字段") | |
| if 'refunded_at' not in columns: | |
| conn.execute(text("ALTER TABLE ownerships ADD COLUMN refunded_at TIMESTAMP")) | |
| logger.info("迁移完成: 添加 ownerships.refunded_at 字段") | |
| conn.commit() | |
| # ========== 2. 迁移 wallets 表(新增 task_balance 字段) ========== | |
| if 'wallets' in inspector.get_table_names(): | |
| columns = [col['name'] for col in inspector.get_columns('wallets')] | |
| with engine.connect() as conn: | |
| if 'task_balance' not in columns: | |
| if 'sqlite' in SQLALCHEMY_DATABASE_URL: | |
| conn.execute(text("ALTER TABLE wallets ADD COLUMN task_balance INTEGER DEFAULT 0")) | |
| else: | |
| conn.execute(text("ALTER TABLE wallets ADD COLUMN task_balance INTEGER DEFAULT 0")) | |
| logger.info("[DB Migration] 添加列 wallets.task_balance") | |
| conn.commit() | |
| # ========== 3. 迁移 transactions 表(提现相关新字段) ========== | |
| if 'transactions' in inspector.get_table_names(): | |
| columns = [col['name'] for col in inspector.get_columns('transactions')] | |
| with engine.connect() as conn: | |
| # 定义 transactions 表的新列 | |
| new_columns = { | |
| 'alipay_account': 'VARCHAR', | |
| 'real_name': 'VARCHAR', | |
| 'withdraw_status': 'VARCHAR', | |
| 'payment_order_id': 'VARCHAR', | |
| 'net_amount': 'INTEGER', | |
| 'description': 'VARCHAR', | |
| 'item_title': 'VARCHAR', | |
| 'item_type': 'VARCHAR', | |
| 'related_user_name': 'VARCHAR', | |
| } | |
| for col_name, col_type in new_columns.items(): | |
| if col_name not in columns: | |
| # VARCHAR 类型添加 DEFAULT NULL 避免 PostgreSQL NOT NULL 冲突 | |
| if col_type == 'VARCHAR': | |
| conn.execute(text(f"ALTER TABLE transactions ADD COLUMN {col_name} {col_type} DEFAULT NULL")) | |
| else: | |
| conn.execute(text(f"ALTER TABLE transactions ADD COLUMN {col_name} {col_type}")) | |
| logger.info(f"[DB Migration] 添加列 transactions.{col_name}") | |
| conn.commit() | |
| # ========== 4. 回填旧提现记录数据 ========== | |
| # 将旧 WITHDRAW 记录的 withdraw_status 从 NULL 更新为 "completed" | |
| result = conn.execute(text( | |
| "UPDATE transactions SET withdraw_status = 'completed' " | |
| "WHERE tx_type = 'WITHDRAW' AND withdraw_status IS NULL" | |
| )) | |
| if result.rowcount > 0: | |
| logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 withdraw_status 为 'completed'") | |
| # 将旧记录的 net_amount 从 NULL 更新为 ABS(amount) | |
| result = conn.execute(text( | |
| "UPDATE transactions SET net_amount = ABS(amount) " | |
| "WHERE tx_type = 'WITHDRAW' AND net_amount IS NULL" | |
| )) | |
| if result.rowcount > 0: | |
| logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 net_amount 为 ABS(amount)") | |
| conn.commit() | |
| except Exception as e: | |
| logger.warning(f"字段迁移跳过 (可能已存在): {e}") | |
| def get_db(): | |
| """获取数据库会话,带连接有效性检测""" | |
| db = SessionLocal() | |
| try: | |
| yield db | |
| except (OperationalError, InterfaceError) as e: | |
| logger.error(f"数据库连接错误: {e}") | |
| db.rollback() | |
| raise | |
| finally: | |
| db.close() | |
| def check_db_connection() -> bool: | |
| """ | |
| 检查数据库连接是否正常 | |
| 用于健康检查接口 | |
| """ | |
| try: | |
| db = SessionLocal() | |
| db.execute(text("SELECT 1")) | |
| db.close() | |
| return True | |
| except Exception as e: | |
| logger.error(f"数据库连接检查失败: {e}") | |
| return False |