ComfyUI-Ranking-API / database_sql.py
ZHIWEI666's picture
优化
8f9d15a verified
raw
history blame
6.52 kB
# database_sql.py
# ==========================================
# 🛡️ 稳定性优化:SQL 数据库连接模块
# ==========================================
import os
import time
import logging
from functools import wraps
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.exc import OperationalError, InterfaceError, DBAPIError
from models_sql import Base
logger = logging.getLogger("ComfyUI-Ranking.Database")
# 核心:优先读取环境变量中的 PostgreSQL 数据库连接
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:////tmp/comfy_financial.db")
# 🚀 P2优化:连接池配置支持环境变量
POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", "5"))
POOL_OVERFLOW = int(os.environ.get("DB_POOL_OVERFLOW", "10"))
# 🚀 P1性能优化:根据数据库类型配置连接池
if "sqlite" in SQLALCHEMY_DATABASE_URL:
# SQLite:使用空连接池,单线程模式
connect_args = {"check_same_thread": False}
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args=connect_args,
poolclass=NullPool # SQLite 不需要连接池
)
logger.info(f"数据库引擎初始化完成 (SQLite)")
else:
# PostgreSQL/MySQL:配置连接池参数
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
poolclass=QueuePool,
pool_size=POOL_SIZE, # 核心连接数(支持环境变量配置)
max_overflow=POOL_OVERFLOW, # 超出 pool_size 后可创建的最大连接数(支持环境变量配置)
pool_timeout=30, # 获取连接超时(秒)
pool_recycle=1800, # 连接回收时间(30分钟),防止数据库断开
pool_pre_ping=True # 使用前检测连接有效性
)
logger.info(f"数据库引擎初始化完成 (PostgreSQL/MySQL) - 连接池配置: pool_size={POOL_SIZE}, max_overflow={POOL_OVERFLOW}")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# ==========================================
# 🛡️ 稳定性优化:数据库操作重试装饰器
# ==========================================
def db_retry(max_retries: int = 3, delay: float = 0.5):
"""
数据库操作重试装饰器
用法:
@db_retry(max_retries=3)
def my_db_function(db):
...
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
last_error = None
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (OperationalError, InterfaceError, DBAPIError) as e:
last_error = e
if attempt < max_retries - 1:
wait_time = delay * (2 ** attempt) # 指数退避
logger.warning(f"DB 操作失败 (第{attempt+1}次):{e}, {wait_time}秒后重试")
time.sleep(wait_time)
else:
logger.error(f"DB 操作失败 (重试{max_retries}次后放弃): {e}")
raise last_error
return wrapper
return decorator
def init_sql_db():
"""初始化数据库,包含重试机制"""
for attempt in range(3):
try:
Base.metadata.create_all(bind=engine)
# 🔄 P7后悔模式:自动迁移新增字段
_auto_migrate_p7_fields()
logger.info("数据库初始化成功")
return
except Exception as e:
if attempt < 2:
logger.warning(f"数据库初始化失败 (第{attempt+1}次): {e}")
time.sleep(2 ** attempt)
else:
logger.error(f"数据库初始化失败 (已重试3次): {e}")
raise
def _auto_migrate_p7_fields():
"""
🔄 P7后悔模式:自动迁移新增字段
检查并添加 ownerships 表的新字段
"""
from sqlalchemy import inspect
try:
inspector = inspect(engine)
# 检查 ownerships 表是否存在
if 'ownerships' in inspector.get_table_names():
columns = [col['name'] for col in inspector.get_columns('ownerships')]
# 添加 P7 新增字段
with engine.connect() as conn:
if 'price_paid' not in columns:
if 'sqlite' in SQLALCHEMY_DATABASE_URL:
conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0"))
else:
conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0"))
logger.info("迁移完成: 添加 ownerships.price_paid 字段")
if 'is_refunded' not in columns:
if 'sqlite' in SQLALCHEMY_DATABASE_URL:
conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT 0"))
else:
conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT FALSE"))
logger.info("迁移完成: 添加 ownerships.is_refunded 字段")
if 'refunded_at' not in columns:
conn.execute(text("ALTER TABLE ownerships ADD COLUMN refunded_at TIMESTAMP"))
logger.info("迁移完成: 添加 ownerships.refunded_at 字段")
conn.commit()
except Exception as e:
logger.warning(f"P7字段迁移跳过 (可能已存在): {e}")
def get_db():
"""获取数据库会话,带连接有效性检测"""
db = SessionLocal()
try:
yield db
except (OperationalError, InterfaceError) as e:
logger.error(f"数据库连接错误: {e}")
db.rollback()
raise
finally:
db.close()
def check_db_connection() -> bool:
"""
检查数据库连接是否正常
用于健康检查接口
"""
try:
db = SessionLocal()
db.execute(text("SELECT 1"))
db.close()
return True
except Exception as e:
logger.error(f"数据库连接检查失败: {e}")
return False