File size: 6,519 Bytes
85494ee
a0ab3de
 
 
85494ee
a0ab3de
 
 
 
85494ee
a0ab3de
 
85494ee
 
a0ab3de
 
66bebb6
85494ee
 
8f9d15a
 
 
 
a0ab3de
 
 
 
 
 
 
 
 
8f9d15a
a0ab3de
 
 
 
 
8f9d15a
 
 
 
 
a0ab3de
8f9d15a
85494ee
 
 
a0ab3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85494ee
a0ab3de
 
 
 
c60e0ef
 
 
 
a0ab3de
 
 
 
 
 
 
 
 
 
85494ee
c60e0ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85494ee
a0ab3de
85494ee
 
 
a0ab3de
 
 
 
85494ee
a0ab3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# database_sql.py
# ==========================================
# 🛡️ 稳定性优化:SQL 数据库连接模块
# ==========================================
import os
import time
import logging
from functools import wraps
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.exc import OperationalError, InterfaceError, DBAPIError
from models_sql import Base

logger = logging.getLogger("ComfyUI-Ranking.Database")

# 核心:优先读取环境变量中的 PostgreSQL 数据库连接
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:////tmp/comfy_financial.db")

# 🚀 P2优化:连接池配置支持环境变量
POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", "5"))
POOL_OVERFLOW = int(os.environ.get("DB_POOL_OVERFLOW", "10"))

# 🚀 P1性能优化:根据数据库类型配置连接池
if "sqlite" in SQLALCHEMY_DATABASE_URL:
    # SQLite:使用空连接池,单线程模式
    connect_args = {"check_same_thread": False}
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL, 
        connect_args=connect_args,
        poolclass=NullPool  # SQLite 不需要连接池
    )
    logger.info(f"数据库引擎初始化完成 (SQLite)")
else:
    # PostgreSQL/MySQL:配置连接池参数
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL,
        poolclass=QueuePool,
        pool_size=POOL_SIZE,           # 核心连接数(支持环境变量配置)
        max_overflow=POOL_OVERFLOW,    # 超出 pool_size 后可创建的最大连接数(支持环境变量配置)
        pool_timeout=30,               # 获取连接超时(秒)
        pool_recycle=1800,             # 连接回收时间(30分钟),防止数据库断开
        pool_pre_ping=True             # 使用前检测连接有效性
    )
    logger.info(f"数据库引擎初始化完成 (PostgreSQL/MySQL) - 连接池配置: pool_size={POOL_SIZE}, max_overflow={POOL_OVERFLOW}")

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# ==========================================
# 🛡️ 稳定性优化:数据库操作重试装饰器
# ==========================================
def db_retry(max_retries: int = 3, delay: float = 0.5):
    """

    数据库操作重试装饰器

    

    用法:

        @db_retry(max_retries=3)

        def my_db_function(db):

            ...

    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except (OperationalError, InterfaceError, DBAPIError) as e:
                    last_error = e
                    if attempt < max_retries - 1:
                        wait_time = delay * (2 ** attempt)  # 指数退避
                        logger.warning(f"DB 操作失败 (第{attempt+1}次):{e}, {wait_time}秒后重试")
                        time.sleep(wait_time)
                    else:
                        logger.error(f"DB 操作失败 (重试{max_retries}次后放弃): {e}")
            raise last_error
        return wrapper
    return decorator


def init_sql_db():
    """初始化数据库,包含重试机制"""
    for attempt in range(3):
        try:
            Base.metadata.create_all(bind=engine)
            
            # 🔄 P7后悔模式:自动迁移新增字段
            _auto_migrate_p7_fields()
            
            logger.info("数据库初始化成功")
            return
        except Exception as e:
            if attempt < 2:
                logger.warning(f"数据库初始化失败 (第{attempt+1}次): {e}")
                time.sleep(2 ** attempt)
            else:
                logger.error(f"数据库初始化失败 (已重试3次): {e}")
                raise


def _auto_migrate_p7_fields():
    """

    🔄 P7后悔模式:自动迁移新增字段

    检查并添加 ownerships 表的新字段

    """
    from sqlalchemy import inspect
    
    try:
        inspector = inspect(engine)
        
        # 检查 ownerships 表是否存在
        if 'ownerships' in inspector.get_table_names():
            columns = [col['name'] for col in inspector.get_columns('ownerships')]
            
            # 添加 P7 新增字段
            with engine.connect() as conn:
                if 'price_paid' not in columns:
                    if 'sqlite' in SQLALCHEMY_DATABASE_URL:
                        conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0"))
                    else:
                        conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0"))
                    logger.info("迁移完成: 添加 ownerships.price_paid 字段")
                
                if 'is_refunded' not in columns:
                    if 'sqlite' in SQLALCHEMY_DATABASE_URL:
                        conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT 0"))
                    else:
                        conn.execute(text("ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT FALSE"))
                    logger.info("迁移完成: 添加 ownerships.is_refunded 字段")
                
                if 'refunded_at' not in columns:
                    conn.execute(text("ALTER TABLE ownerships ADD COLUMN refunded_at TIMESTAMP"))
                    logger.info("迁移完成: 添加 ownerships.refunded_at 字段")
                
                conn.commit()
    except Exception as e:
        logger.warning(f"P7字段迁移跳过 (可能已存在): {e}")


def get_db():
    """获取数据库会话,带连接有效性检测"""
    db = SessionLocal()
    try:
        yield db
    except (OperationalError, InterfaceError) as e:
        logger.error(f"数据库连接错误: {e}")
        db.rollback()
        raise
    finally:
        db.close()


def check_db_connection() -> bool:
    """

    检查数据库连接是否正常

    用于健康检查接口

    """
    try:
        db = SessionLocal()
        db.execute(text("SELECT 1"))
        db.close()
        return True
    except Exception as e:
        logger.error(f"数据库连接检查失败: {e}")
        return False