Spaces:
Running
Running
| # database_sql.py | |
| # ========================================== | |
| # 🛡️ 稳定性优化:SQL 数据库连接模块 | |
| # ========================================== | |
| import os | |
| import time | |
| import logging | |
| from functools import wraps | |
| from sqlalchemy import create_engine, text | |
| from sqlalchemy.orm import sessionmaker | |
| from sqlalchemy.pool import QueuePool, NullPool | |
| from sqlalchemy.exc import OperationalError, InterfaceError, DBAPIError | |
| from models_sql import Base | |
| logger = logging.getLogger("ComfyUI-Ranking.Database") | |
| # 核心:优先读取环境变量中的 PostgreSQL 数据库连接 | |
| SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:////tmp/comfy_financial.db") | |
| # 🚀 P1性能优化:根据数据库类型配置连接池 | |
| if "sqlite" in SQLALCHEMY_DATABASE_URL: | |
| # SQLite:使用空连接池,单线程模式 | |
| connect_args = {"check_same_thread": False} | |
| engine = create_engine( | |
| SQLALCHEMY_DATABASE_URL, | |
| connect_args=connect_args, | |
| poolclass=NullPool # SQLite 不需要连接池 | |
| ) | |
| else: | |
| # PostgreSQL/MySQL:配置连接池参数 | |
| engine = create_engine( | |
| SQLALCHEMY_DATABASE_URL, | |
| poolclass=QueuePool, | |
| pool_size=5, # 核心连接数 | |
| max_overflow=10, # 超出 pool_size 后可创建的最大连接数 | |
| pool_timeout=30, # 获取连接超时(秒) | |
| pool_recycle=1800, # 连接回收时间(30分钟),防止数据库断开 | |
| pool_pre_ping=True # 使用前检测连接有效性 | |
| ) | |
| SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | |
| # ========================================== | |
| # 🛡️ 稳定性优化:数据库操作重试装饰器 | |
| # ========================================== | |
| def db_retry(max_retries: int = 3, delay: float = 0.5): | |
| """ | |
| 数据库操作重试装饰器 | |
| 用法: | |
| @db_retry(max_retries=3) | |
| def my_db_function(db): | |
| ... | |
| """ | |
| def decorator(func): | |
| def wrapper(*args, **kwargs): | |
| last_error = None | |
| for attempt in range(max_retries): | |
| try: | |
| return func(*args, **kwargs) | |
| except (OperationalError, InterfaceError, DBAPIError) as e: | |
| last_error = e | |
| if attempt < max_retries - 1: | |
| wait_time = delay * (2 ** attempt) # 指数退避 | |
| logger.warning(f"DB 操作失败 (第{attempt+1}次):{e}, {wait_time}秒后重试") | |
| time.sleep(wait_time) | |
| else: | |
| logger.error(f"DB 操作失败 (重试{max_retries}次后放弃): {e}") | |
| raise last_error | |
| return wrapper | |
| return decorator | |
| def init_sql_db(): | |
| """初始化数据库,包含重试机制""" | |
| for attempt in range(3): | |
| try: | |
| Base.metadata.create_all(bind=engine) | |
| logger.info("数据库初始化成功") | |
| return | |
| except Exception as e: | |
| if attempt < 2: | |
| logger.warning(f"数据库初始化失败 (第{attempt+1}次): {e}") | |
| time.sleep(2 ** attempt) | |
| else: | |
| logger.error(f"数据库初始化失败 (已重试3次): {e}") | |
| raise | |
| def get_db(): | |
| """获取数据库会话,带连接有效性检测""" | |
| db = SessionLocal() | |
| try: | |
| yield db | |
| except (OperationalError, InterfaceError) as e: | |
| logger.error(f"数据库连接错误: {e}") | |
| db.rollback() | |
| raise | |
| finally: | |
| db.close() | |
| def check_db_connection() -> bool: | |
| """ | |
| 检查数据库连接是否正常 | |
| 用于健康检查接口 | |
| """ | |
| try: | |
| db = SessionLocal() | |
| db.execute(text("SELECT 1")) | |
| db.close() | |
| return True | |
| except Exception as e: | |
| logger.error(f"数据库连接检查失败: {e}") | |
| return False |