File size: 10,386 Bytes
85494ee
a0ab3de
 
 
85494ee
a0ab3de
 
 
 
85494ee
a0ab3de
 
85494ee
 
a0ab3de
 
66bebb6
85494ee
 
8f9d15a
 
 
 
a0ab3de
 
 
 
 
 
 
 
 
8f9d15a
a0ab3de
 
820f705
 
57fd107
 
 
 
a0ab3de
 
 
8f9d15a
 
 
820f705
 
 
a0ab3de
8f9d15a
85494ee
 
 
a0ab3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85494ee
a0ab3de
 
 
 
c60e0ef
 
 
 
a0ab3de
 
 
 
 
 
 
 
 
 
85494ee
c60e0ef
 
 
c71735f
c60e0ef
 
 
 
 
 
c71735f
c60e0ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c71735f
3c40bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c71735f
 
 
 
 
 
 
 
 
 
 
e4b11fa
 
 
 
c71735f
 
 
 
f01b4f8
 
 
 
 
c71735f
 
 
 
3c40bd1
d6cee16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c60e0ef
c71735f
c60e0ef
 
85494ee
a0ab3de
85494ee
 
 
a0ab3de
 
 
 
85494ee
a0ab3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# database_sql.py
# ==========================================
# 🛡️ 稳定性优化:SQL 数据库连接模块
# ==========================================
import os
import time
import logging
from functools import wraps
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.exc import OperationalError, InterfaceError, DBAPIError
from models_sql import Base

logger = logging.getLogger("ComfyUI-Ranking.Database")


def _normalize_database_url(url):
    """
    Return ``(url, backend)`` for a database URL.

    ``backend`` is the dialect part of the URL scheme, lower-cased and without
    any ``+driver`` suffix (``"postgresql+psycopg2://..."`` -> ``"postgresql"``).
    Only the scheme is inspected: searching the whole URL for "sqlite" or
    "postgres" would also match host names, database names or passwords.

    A legacy ``postgres`` scheme is rewritten to ``postgresql``, because
    SQLAlchemy 1.4+ no longer accepts the ``postgres`` dialect name while some
    hosting providers still issue such URLs.
    """
    scheme, sep, rest = url.partition("://")
    backend = scheme.split("+", 1)[0].lower()
    if sep and backend == "postgres":
        url = "postgresql" + scheme[len("postgres"):] + sep + rest
        backend = "postgresql"
    return url, backend


# Prefer the database URL (PostgreSQL in production) from the environment;
# fall back to a local SQLite file.
SQLALCHEMY_DATABASE_URL, _DB_BACKEND = _normalize_database_url(
    os.environ.get("DATABASE_URL", "sqlite:////tmp/comfy_financial.db")
)

# Connection-pool sizing, configurable through environment variables.
POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", "5"))
POOL_OVERFLOW = int(os.environ.get("DB_POOL_OVERFLOW", "10"))

if _DB_BACKEND == "sqlite":
    # SQLite: no pooling (NullPool). check_same_thread=False lets a connection
    # be used from a thread other than the one that created it.
    connect_args = {"check_same_thread": False}
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL,
        connect_args=connect_args,
        poolclass=NullPool,
    )
    logger.info("数据库引擎初始化完成 (SQLite)")
else:
    # PostgreSQL / MySQL: pooled connections.
    connect_args = {}
    if _DB_BACKEND == "postgresql":
        # Passed straight to the DBAPI driver (libpq-style option names, as
        # used by psycopg2): SSL mode and connect timeout in seconds.
        sslmode = os.environ.get("DB_SSLMODE", "require")
        connect_args["sslmode"] = sslmode
        connect_args["connect_timeout"] = int(os.environ.get("DB_CONNECT_TIMEOUT", "10"))
        logger.info(f"PostgreSQL SSL 配置: sslmode={sslmode}, connect_timeout={connect_args['connect_timeout']}")
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL,
        poolclass=QueuePool,
        pool_size=POOL_SIZE,           # persistent connections kept in the pool
        max_overflow=POOL_OVERFLOW,    # extra connections allowed beyond pool_size
        pool_timeout=30,               # seconds to wait for a free connection
        pool_recycle=1800,             # recycle connections older than 30 minutes
        pool_pre_ping=True,            # test each connection before handing it out
        connect_args=connect_args,
    )
    logger.info(f"数据库引擎初始化完成 (PostgreSQL/MySQL) - 连接池配置: pool_size={POOL_SIZE}, max_overflow={POOL_OVERFLOW}")

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# ==========================================
# 🛡️ 稳定性优化:数据库操作重试装饰器
# ==========================================
def db_retry(max_retries: int = 3, delay: float = 0.5):
    """
    Decorator that retries a database operation on transient connection errors.

    Only errors that indicate a connection problem are retried:
    ``OperationalError``, ``InterfaceError`` and any ``DBAPIError`` whose
    ``connection_invalidated`` flag is set. Other ``DBAPIError`` subclasses
    (e.g. ``IntegrityError``, ``ProgrammingError``, ``DataError``) fail the
    same way on every attempt, so they are re-raised immediately instead of
    being retried with back-off.

    Usage:
        @db_retry(max_retries=3)
        def my_db_function(db):
            ...

    Args:
        max_retries: total number of attempts; values below 1 are treated as 1.
        delay: base back-off in seconds; after failed attempt ``k`` (0-based)
            the wrapper sleeps ``delay * 2**k`` before trying again.

    Raises:
        The last caught error once all attempts are exhausted, or any
        non-transient database error immediately.

    NOTE(review): the wrapped function is re-invoked with the same arguments.
    If one of them is a Session left in a failed transaction, it may need a
    rollback before a retry can succeed — confirm against callers.
    """
    # Guard against max_retries <= 0, which previously ended in `raise None`.
    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except (OperationalError, InterfaceError, DBAPIError) as e:
                    transient = isinstance(e, (OperationalError, InterfaceError)) or getattr(
                        e, "connection_invalidated", False
                    )
                    if not transient:
                        raise
                    if attempt == attempts - 1:
                        logger.error(f"DB 操作失败 (重试{max_retries}次后放弃): {e}")
                        raise
                    wait_time = delay * (2 ** attempt)  # exponential back-off
                    logger.warning(f"DB 操作失败 (第{attempt+1}次):{e}, {wait_time}秒后重试")
                    time.sleep(wait_time)
        return wrapper
    return decorator


def init_sql_db():
    """
    Create all tables declared on ``Base`` and apply the column migrations.

    Up to 3 attempts are made; after a failed attempt ``k`` (0-based) the
    function sleeps ``2**k`` seconds. Every failure is logged, and the error
    from the final attempt is re-raised.
    """
    max_attempts = 3
    attempt = 0
    while True:
        try:
            Base.metadata.create_all(bind=engine)
            # P7 refund mode: add newly introduced columns to existing tables.
            _auto_migrate_p7_fields()
        except Exception as e:
            if attempt >= max_attempts - 1:
                logger.error(f"数据库初始化失败 (已重试3次): {e}")
                raise
            logger.warning(f"数据库初始化失败 (第{attempt+1}次): {e}")
            time.sleep(2 ** attempt)
            attempt += 1
        else:
            logger.info("数据库初始化成功")
            return


def _auto_migrate_p7_fields():
    """
    Add columns introduced after the initial schema to tables that already exist.

    Covers the P7 refund columns on ``ownerships``, ``wallets.task_balance``
    and the withdrawal-related columns on ``transactions`` (plus a backfill of
    old WITHDRAW rows). Each table is migrated in its own transaction; a
    failure in one table is logged as a warning and does not stop the
    remaining tables from being migrated. Errors never propagate to the caller.
    """
    from sqlalchemy import inspect

    try:
        inspector = inspect(engine)
        existing_tables = set(inspector.get_table_names())
    except Exception as e:
        logger.warning(f"字段迁移跳过 (可能已存在): {e}")
        return

    migrations = (
        ("ownerships", _migrate_ownerships_columns),
        ("wallets", _migrate_wallets_columns),
        ("transactions", _migrate_transactions_columns),
    )
    for table_name, migrate in migrations:
        if table_name not in existing_tables:
            continue
        try:
            columns = {col["name"] for col in inspector.get_columns(table_name)}
            migrate(columns)
        except Exception as e:
            # Best-effort: log and move on so one table cannot block the others.
            logger.warning(f"字段迁移跳过 (可能已存在): {e}")


def _migrate_ownerships_columns(columns):
    """Add price_paid / is_refunded / refunded_at to ``ownerships`` when missing."""
    # Use the engine's dialect instead of a substring search of the URL.
    is_sqlite = engine.dialect.name == "sqlite"
    with engine.begin() as conn:
        if "price_paid" not in columns:
            conn.execute(text("ALTER TABLE ownerships ADD COLUMN price_paid INTEGER DEFAULT 0"))
            logger.info("迁移完成: 添加 ownerships.price_paid 字段")

        if "is_refunded" not in columns:
            # Older SQLite releases do not recognise the FALSE keyword, so use 0 there.
            bool_default = "0" if is_sqlite else "FALSE"
            conn.execute(text(f"ALTER TABLE ownerships ADD COLUMN is_refunded BOOLEAN DEFAULT {bool_default}"))
            logger.info("迁移完成: 添加 ownerships.is_refunded 字段")

        if "refunded_at" not in columns:
            conn.execute(text("ALTER TABLE ownerships ADD COLUMN refunded_at TIMESTAMP"))
            logger.info("迁移完成: 添加 ownerships.refunded_at 字段")


def _migrate_wallets_columns(columns):
    """Add ``wallets.task_balance`` when missing."""
    if "task_balance" in columns:
        return
    with engine.begin() as conn:
        conn.execute(text("ALTER TABLE wallets ADD COLUMN task_balance INTEGER DEFAULT 0"))
        logger.info("[DB Migration] 添加列 wallets.task_balance")


def _migrate_transactions_columns(columns):
    """Add withdrawal-related columns to ``transactions`` and backfill old WITHDRAW rows."""
    # Column names/types are fixed literals here, so building the DDL with an
    # f-string does not expose any injection surface.
    new_columns = {
        "alipay_account": "VARCHAR",
        "real_name": "VARCHAR",
        "withdraw_status": "VARCHAR",
        "payment_order_id": "VARCHAR",
        "net_amount": "INTEGER",
        "description": "VARCHAR",
        "item_title": "VARCHAR",
        "item_type": "VARCHAR",
        "related_user_name": "VARCHAR",
    }

    with engine.begin() as conn:
        for col_name, col_type in new_columns.items():
            if col_name in columns:
                continue
            # VARCHAR columns get an explicit DEFAULT NULL; none of the new
            # columns is declared NOT NULL, so all of them are nullable.
            default_clause = " DEFAULT NULL" if col_type == "VARCHAR" else ""
            conn.execute(text(f"ALTER TABLE transactions ADD COLUMN {col_name} {col_type}{default_clause}"))
            logger.info(f"[DB Migration] 添加列 transactions.{col_name}")

    # Backfill in a second transaction, after the new columns are committed.
    # Runs on every start-up; the IS NULL filters make it idempotent.
    # NOTE(review): any WITHDRAW row later written with a NULL withdraw_status
    # will also be marked 'completed' here — confirm this is intended.
    with engine.begin() as conn:
        result = conn.execute(text(
            "UPDATE transactions SET withdraw_status = 'completed' "
            "WHERE tx_type = 'WITHDRAW' AND withdraw_status IS NULL"
        ))
        if result.rowcount > 0:
            logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 withdraw_status 为 'completed'")

        result = conn.execute(text(
            "UPDATE transactions SET net_amount = ABS(amount) "
            "WHERE tx_type = 'WITHDRAW' AND net_amount IS NULL"
        ))
        if result.rowcount > 0:
            logger.info(f"[DB Migration] 回填旧提现记录: 更新 {result.rowcount} 条记录的 net_amount 为 ABS(amount)")


def get_db():
    """
    Yield a new session from ``SessionLocal`` and always close it afterwards.

    If an ``OperationalError`` or ``InterfaceError`` is raised back into this
    generator at the ``yield``, the error is logged, the session is rolled
    back, and the error is re-raised.
    """
    session = SessionLocal()
    try:
        try:
            yield session
        except (OperationalError, InterfaceError) as exc:
            logger.error(f"数据库连接错误: {exc}")
            session.rollback()
            raise
    finally:
        session.close()


def check_db_connection() -> bool:
    """
    Check whether the database is reachable by running ``SELECT 1``.

    Used by the health-check endpoint.

    Returns:
        True if the query succeeded; False on any error (the error is logged).
    """
    try:
        db = SessionLocal()
        try:
            db.execute(text("SELECT 1"))
        finally:
            # Always release the session, even when the probe fails, so a
            # failing health check does not keep a connection checked out.
            db.close()
        return True
    except Exception as e:
        logger.error(f"数据库连接检查失败: {e}")
        return False