File size: 8,442 Bytes
7c15d35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.config import settings

# ``check_same_thread`` is an argument of Python's sqlite3 driver only; setting it to
# False lets a SQLite connection be used from threads other than the one that opened
# it. Detect SQLite from the URL scheme prefix rather than a substring anywhere in the
# URL, so that e.g. a PostgreSQL database whose *name* contains "sqlite" is not handed
# a connect argument its driver would reject.
engine = create_engine(
    settings.DATABASE_URL,
    connect_args=(
        {"check_same_thread": False}
        if settings.DATABASE_URL.startswith("sqlite")
        else {}
    ),
)

# Session factory: no autocommit and no autoflush, so callers control flush/commit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base; tables of the models registered on Base.metadata are what
# init_db() creates.
Base = declarative_base()


def get_db():
    """Provide a database session as a generator-based dependency.

    A fresh session is created from ``SessionLocal`` and handed out by the single
    ``yield``; the ``finally`` clause closes it once the generator finishes or is
    closed, even if the consumer raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()


def init_db():
    """Bring the database schema up to date at startup.

    Steps, in order:
      1. import ``app.models.models`` so its models are registered on ``Base.metadata``;
      2. ``create_all`` for any tables that do not exist yet;
      3. add columns missing from pre-existing tables (``_migrate_database_schema``);
      4. create performance indexes (``_add_performance_indexes``).

    Returns:
        True once every step has finished.

    Raises:
        Exception: any exception escaping the steps above is re-raised after a
        failure message has been printed.
    """
    try:
        # Imported only for its side effect: defining the model classes registers
        # their tables on Base.metadata, which create_all() relies on.
        from app.models import models  # noqa: F401

        setup_steps = (
            lambda: Base.metadata.create_all(bind=engine),
            _migrate_database_schema,
            _add_performance_indexes,
        )
        for run_step in setup_steps:
            run_step()

        print("✓ 数据库初始化成功")
        return True
    except Exception as exc:
        print(f"✗ 数据库初始化失败: {str(exc)}")
        raise


def _add_column_safely(conn, table_name, column_name, column_def):
    """Try to add one column to an existing table, without raising.

    Executes ``ALTER TABLE <table_name> ADD COLUMN <column_name> <column_def>`` on
    ``conn`` and commits it. On failure the transaction is rolled back and a warning
    is printed (instead of being silently discarded), so callers can carry on with
    the remaining migrations.

    All three names are interpolated directly into the SQL text, so they must come
    from trusted, hard-coded migration definitions — never from user input.

    Args:
        conn: An open SQLAlchemy connection supporting ``execute``/``commit``/``rollback``.
        table_name: Table to alter.
        column_name: Column to add.
        column_def: SQL type plus optional constraints/default, e.g. ``"INTEGER DEFAULT 0"``.

    Returns:
        True if the column was added and committed, False otherwise.
    """
    try:
        conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_def}"))
        conn.commit()
    except Exception as e:
        # The column may already exist, or the backend may reject column_def;
        # either way, report it and let the caller continue.
        conn.rollback()
        print(f"  ⚠ 添加字段失败 {table_name}.{column_name}: {e}")
        return False
    return True


def _add_performance_indexes():
    """Create single-column secondary indexes on the optimization/change-log tables.

    Best-effort: tables that do not exist are skipped, indexes already reported by
    the inspector are skipped, and each index is created and committed on its own.
    A failure for one index is rolled back and printed as a warning (previously it
    was silently discarded); the remaining indexes are still attempted. Nothing is
    raised, so application startup is never blocked by this step.
    """
    # (index name, table name, column name)
    indexes = [
        # OptimizationSession indexes
        ("idx_opt_session_user_id", "optimization_sessions", "user_id"),
        ("idx_opt_session_status", "optimization_sessions", "status"),
        ("idx_opt_session_created_at", "optimization_sessions", "created_at"),

        # OptimizationSegment indexes
        ("idx_opt_segment_session_id", "optimization_segments", "session_id"),
        ("idx_opt_segment_index", "optimization_segments", "segment_index"),
        ("idx_opt_segment_status", "optimization_segments", "status"),

        # ChangeLog indexes
        ("idx_change_log_session_id", "change_logs", "session_id"),
        ("idx_change_log_segment_index", "change_logs", "segment_index"),
        ("idx_change_log_stage", "change_logs", "stage"),
    ]

    try:
        inspector = inspect(engine)
        tables = set(inspector.get_table_names())

        with engine.connect() as conn:
            for index_name, table_name, column_name in indexes:
                if table_name not in tables:
                    continue

                try:
                    existing_names = {idx["name"] for idx in inspector.get_indexes(table_name)}
                    if index_name in existing_names:
                        continue

                    # IF NOT EXISTS additionally guards against an index created after
                    # the inspector looked (syntax accepted by SQLite and PostgreSQL).
                    conn.execute(text(
                        f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name})"
                    ))
                    conn.commit()
                    print(f"  ✓ 添加索引: {index_name}")
                except Exception as e:
                    # Roll back only this index and keep going with the others.
                    conn.rollback()
                    print(f"  ⚠ 添加索引失败 {index_name}: {e}")

    except Exception as e:
        # Failure must not prevent application startup.
        print(f"  ⚠ 添加性能索引警告: {str(e)}")


def _column_migrations():
    """Return, per table, the ``(column, SQL definition)`` pairs that may need adding."""
    # PostgreSQL rejects an integer default on a BOOLEAN column
    # ("default expression is of type integer"), so use boolean literals there;
    # other backends keep the original 0/1 defaults.
    if engine.dialect.name == "postgresql":
        false_sql, true_sql = "FALSE", "TRUE"
    else:
        false_sql, true_sql = "0", "1"

    return {
        "optimization_sessions": [
            ("failed_segment_index", "INTEGER"),
            ("processing_mode", "VARCHAR(50) DEFAULT 'paper_polish_enhance'"),
            # Each emotion_* column is checked on its own, so a partially applied
            # earlier migration is completed instead of being skipped.
            ("emotion_model", "VARCHAR(100)"),
            ("emotion_api_key", "VARCHAR(255)"),
            ("emotion_base_url", "VARCHAR(255)"),
        ],
        "users": [
            ("usage_limit", f"INTEGER DEFAULT {settings.DEFAULT_USAGE_LIMIT}"),
            ("usage_count", "INTEGER DEFAULT 0"),
        ],
        "optimization_segments": [
            ("is_title", f"BOOLEAN DEFAULT {false_sql}"),
        ],
        "custom_prompts": [
            ("is_system", f"BOOLEAN DEFAULT {false_sql}"),
            ("is_active", f"BOOLEAN DEFAULT {true_sql}"),
        ],
    }


def _backfill_user_usage(conn):
    """Replace NULL ``usage_limit``/``usage_count`` values in ``users`` with defaults (best-effort)."""
    try:
        conn.execute(text(f"UPDATE users SET usage_limit = {settings.DEFAULT_USAGE_LIMIT} WHERE usage_limit IS NULL"))
        conn.execute(text("UPDATE users SET usage_count = 0 WHERE usage_count IS NULL"))
        conn.commit()
    except Exception as e:
        # e.g. one of the columns could not be added; leave existing rows untouched.
        conn.rollback()
        print(f"  ⚠ 更新 users 默认值失败: {e}")


def _migrate_database_schema():
    """Add columns introduced after a table was first created (lightweight migration).

    For every table listed in ``_column_migrations()`` that already exists, each
    missing column is added via ``_add_column_safely`` and reported when it succeeds.
    For ``users``, NULL usage counters are then backfilled with their defaults.
    Tables that do not exist are skipped (``create_all`` creates them complete).

    Best-effort: any unexpected error is printed as a warning and not raised, so a
    failed migration never blocks application startup.
    """
    try:
        inspector = inspect(engine)
        tables = set(inspector.get_table_names())

        with engine.connect() as conn:
            for table_name, column_specs in _column_migrations().items():
                if table_name not in tables:
                    continue

                existing = {column["name"] for column in inspector.get_columns(table_name)}
                for column_name, column_def in column_specs:
                    if column_name in existing:
                        continue
                    if _add_column_safely(conn, table_name, column_name, column_def):
                        print(f"  ✓ 添加字段: {table_name}.{column_name}")

                if table_name == "users":
                    _backfill_user_usage(conn)

    except Exception as e:
        # Migration failure must not prevent application startup.
        print(f"  ⚠ 数据库迁移警告: {str(e)}")