ComfyUI-Ranking-API / database_sql.py
ZHIWEI666's picture
Upload 4 files
3c40bd1 verified
# database_sql.py
# ==========================================
# 🛡️ 稳定性优化:SQL 数据库连接模块
# ==========================================
import os
import time
import logging
from functools import wraps
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.exc import OperationalError, InterfaceError, DBAPIError
from models_sql import Base
logger = logging.getLogger("ComfyUI-Ranking.Database")

# Connection string: DATABASE_URL from the environment wins; otherwise fall
# back to a SQLite file under /tmp.
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:////tmp/comfy_financial.db")

# SQLAlchemy 1.4+ rejects the legacy "postgres://" scheme. Rewrite it to the
# supported "postgresql://" form so such a DATABASE_URL keeps working.
if SQLALCHEMY_DATABASE_URL.startswith("postgres://"):
    SQLALCHEMY_DATABASE_URL = "postgresql://" + SQLALCHEMY_DATABASE_URL[len("postgres://"):]

# Connection-pool sizing, overridable through environment variables.
POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", "5"))
POOL_OVERFLOW = int(os.environ.get("DB_POOL_OVERFLOW", "10"))

# Pick the pooling strategy from the database type named in the URL.
if "sqlite" in SQLALCHEMY_DATABASE_URL:
    # SQLite: no connection pool (NullPool); check_same_thread is disabled so
    # a connection may be used outside the thread that created it.
    connect_args = {"check_same_thread": False}
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL,
        connect_args=connect_args,
        poolclass=NullPool,
    )
    logger.info("数据库引擎初始化完成 (SQLite)")
else:
    # PostgreSQL/MySQL: pooled connections.
    connect_args = {}
    # "postgres" is a substring of "postgresql", so one test covers both.
    if "postgres" in SQLALCHEMY_DATABASE_URL:
        sslmode = os.environ.get("DB_SSLMODE", "require")
        connect_args["sslmode"] = sslmode
        connect_args["connect_timeout"] = int(os.environ.get("DB_CONNECT_TIMEOUT", "10"))
        logger.info(f"PostgreSQL SSL 配置: sslmode={sslmode}, connect_timeout={connect_args['connect_timeout']}")
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL,
        poolclass=QueuePool,
        pool_size=POOL_SIZE,         # persistent connections kept in the pool
        max_overflow=POOL_OVERFLOW,  # extra connections allowed beyond pool_size
        pool_timeout=30,             # seconds to wait for a free connection
        pool_recycle=1800,           # recycle connections older than 30 minutes
        pool_pre_ping=True,          # test a connection before handing it out
        connect_args=connect_args,
    )
    logger.info(f"数据库引擎初始化完成 (PostgreSQL/MySQL) - 连接池配置: pool_size={POOL_SIZE}, max_overflow={POOL_OVERFLOW}")

# Session factory bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# ==========================================
# 🛡️ 稳定性优化:数据库操作重试装饰器
# ==========================================
def db_retry(max_retries: int = 3, delay: float = 0.5):
    """
    Build a decorator that retries a database operation with exponential backoff.

    Usage:
        @db_retry(max_retries=3)
        def my_db_function(db):
            ...

    Args:
        max_retries: Total number of attempts (the first call included).
            Values below 1 are treated as 1 so the wrapped function always
            runs at least once.
        delay: Base wait in seconds. The wait after failed attempt ``k``
            (0-based) is ``delay * 2 ** k``.

    Returns:
        A decorator. The decorated function returns whatever the wrapped
        function returns.

    Raises:
        The exception from the final attempt, once all attempts have failed.
        Exceptions outside the caught SQLAlchemy types propagate immediately.
    """
    # With a non-positive count, range() would be empty: the wrapped function
    # would never be called and the wrapper would silently return None.
    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                # NOTE(review): in SQLAlchemy, DBAPIError is the base class of
                # OperationalError and InterfaceError, so this tuple also
                # retries non-transient errors such as IntegrityError.
                # Confirm that is intended.
                except (OperationalError, InterfaceError, DBAPIError) as e:
                    if attempt >= attempts - 1:
                        logger.error(f"DB 操作失败 (重试{max_retries}次后放弃): {e}")
                        # Bare raise keeps the original traceback intact.
                        raise
                    wait_time = delay * (2 ** attempt)  # exponential backoff
                    logger.warning(f"DB 操作失败 (第{attempt+1}次):{e}, {wait_time}秒后重试")
                    time.sleep(wait_time)
        return wrapper
    return decorator
def init_sql_db():
    """
    Create all tables and run the column migrations, with up to three attempts.

    After a failed first or second attempt a warning is logged and the
    function sleeps 1 s and 2 s respectively before trying again. A failure
    on the third attempt is logged as an error and re-raised to the caller.
    """
    total_attempts = 3
    attempt_no = 0
    while True:
        attempt_no += 1
        try:
            Base.metadata.create_all(bind=engine)
            # Add columns introduced after the tables were first created.
            _auto_migrate_p7_fields()
            logger.info("数据库初始化成功")
            return
        except Exception as e:
            if attempt_no >= total_attempts:
                logger.error(f"数据库初始化失败 (已重试3次): {e}")
                raise
            logger.warning(f"数据库初始化失败 (第{attempt_no}次): {e}")
            # Backoff doubles each round: 1 s, then 2 s.
            time.sleep(2 ** (attempt_no - 1))
def _auto_migrate_p7_fields():
    """
    Best-effort schema migration: add columns that an older database lacks.

    Covers the ``ownerships``, ``wallets`` and ``transactions`` tables and
    then backfills legacy WITHDRAW rows in ``transactions``. Tables that do
    not exist are skipped.

    Each table is migrated independently: a failure on one table is logged
    as a warning and the remaining tables are still processed. This function
    never raises.
    """
    from sqlalchemy import inspect

    try:
        inspector = inspect(engine)
        existing_tables = set(inspector.get_table_names())
    except Exception as e:
        logger.warning(f"字段迁移跳过 (可能已存在): {e}")
        return

    steps = (
        ('ownerships', _migrate_ownerships_columns),
        ('wallets', _migrate_wallets_columns),
        ('transactions', _migrate_transactions_columns),
    )
    for table_name, migrate in steps:
        if table_name not in existing_tables:
            continue
        try:
            migrate(inspector)
        except Exception as e:
            # Isolated per table so one failure does not silently skip the
            # migrations that follow it.
            logger.warning(f"字段迁移跳过 (可能已存在): {e}")


def _migrate_ownerships_columns(inspector):
    """Add price_paid / is_refunded / refunded_at to ``ownerships`` if missing."""
    existing = {col['name'] for col in inspector.get_columns('ownerships')}
    # Only the BOOLEAN default literal differs between SQLite and the rest.
    bool_default = "0" if 'sqlite' in SQLALCHEMY_DATABASE_URL else "FALSE"
    column_ddl = (
        ('price_paid', "INTEGER DEFAULT 0"),
        ('is_refunded', f"BOOLEAN DEFAULT {bool_default}"),
        ('refunded_at', "TIMESTAMP"),
    )
    # NOTE(review): conn.commit() exists on Connection only in SQLAlchemy 2.0
    # or 1.4 "future" mode. Confirm the deployed version.
    with engine.connect() as conn:
        for col_name, ddl in column_ddl:
            if col_name in existing:
                continue
            conn.execute(text(f"ALTER TABLE ownerships ADD COLUMN {col_name} {ddl}"))
            logger.info(f"迁移完成: 添加 ownerships.{col_name} 字段")
        conn.commit()


def _migrate_wallets_columns(inspector):
    """Add task_balance to ``wallets`` if missing."""
    existing = {col['name'] for col in inspector.get_columns('wallets')}
    with engine.connect() as conn:
        if 'task_balance' not in existing:
            conn.execute(text("ALTER TABLE wallets ADD COLUMN task_balance INTEGER DEFAULT 0"))
            logger.info("[DB Migration] 添加列 wallets.task_balance")
        conn.commit()


def _migrate_transactions_columns(inspector):
    """Add the withdrawal-related columns to ``transactions`` and backfill old rows."""
    existing = {col['name'] for col in inspector.get_columns('transactions')}
    # Column name -> SQL type for every column this migration may add.
    new_columns = {
        'alipay_account': 'VARCHAR',
        'real_name': 'VARCHAR',
        'withdraw_status': 'VARCHAR',
        'payment_order_id': 'VARCHAR',
        'net_amount': 'INTEGER',
        'description': 'VARCHAR',
        'item_title': 'VARCHAR',
        'item_type': 'VARCHAR',
        'related_user_name': 'VARCHAR',
    }
    with engine.connect() as conn:
        for col_name, col_type in new_columns.items():
            if col_name in existing:
                continue
            ddl = f"ALTER TABLE transactions ADD COLUMN {col_name} {col_type}"
            if col_type == 'VARCHAR':
                # Explicit DEFAULT NULL for VARCHAR columns (per the original
                # author: avoids a PostgreSQL NOT NULL conflict).
                ddl += " DEFAULT NULL"
            conn.execute(text(ddl))
            logger.info(f"[DB Migration] 添加列 transactions.{col_name}")
        conn.commit()

        # Backfill legacy withdrawal rows.
        # NOTE(review): this runs on every call, not only right after the
        # columns are added. Any WITHDRAW row whose withdraw_status is NULL
        # at that moment is marked 'completed'. Confirm no in-flight
        # withdrawal is ever stored with a NULL status.
        result = conn.execute(text(
            "UPDATE transactions SET withdraw_status = 'completed' "
            "WHERE tx_type = 'WITHDRAW' AND withdraw_status IS NULL"
        ))
        if result.rowcount > 0:
            logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 withdraw_status 为 'completed'")

        # Legacy rows without a net_amount get ABS(amount).
        result = conn.execute(text(
            "UPDATE transactions SET net_amount = ABS(amount) "
            "WHERE tx_type = 'WITHDRAW' AND net_amount IS NULL"
        ))
        if result.rowcount > 0:
            logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 net_amount 为 ABS(amount)")
        conn.commit()
def get_db():
    """
    Yield a new ``SessionLocal`` session and close it when the consumer is done.

    If an ``OperationalError`` or ``InterfaceError`` is raised into the
    generator while the session is in use, the error is logged, the session
    is rolled back and the exception is re-raised. The session is closed in
    every case. No connectivity probe is performed here.
    """
    session = SessionLocal()
    try:
        try:
            yield session
        except (OperationalError, InterfaceError) as exc:
            logger.error(f"数据库连接错误: {exc}")
            session.rollback()
            raise
    finally:
        session.close()
def check_db_connection() -> bool:
    """
    Report whether the database answers a trivial query.

    Intended for the health-check endpoint.

    Returns:
        True when ``SELECT 1`` runs on a fresh session. False on any
        exception, which is logged rather than raised.
    """
    try:
        db = SessionLocal()
        try:
            db.execute(text("SELECT 1"))
        finally:
            # Close even when the probe query fails, so the session is not
            # left open on the failure path.
            db.close()
        return True
    except Exception as e:
        logger.error(f"数据库连接检查失败: {e}")
        return False