File size: 6,604 Bytes
7c15d35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
#!/usr/bin/env python3
"""
数据库初始化和健康检查脚本
可以独立运行以测试数据库连接和初始化
"""

import sys
import os
from pathlib import Path

# 添加项目路径到 Python 路径
backend_dir = Path(__file__).parent
sys.path.insert(0, str(backend_dir))

from app.database import init_db, engine, SessionLocal
from app.models.models import User, OptimizationSession, CustomPrompt, SystemSetting
from sqlalchemy import text, inspect


def check_database_connection():
    """检查数据库连接"""
    print("检查数据库连接...")
    try:
        with engine.connect() as conn:
            result = conn.execute(text("SELECT 1"))
            result.fetchone()
        print("✓ 数据库连接成功")
        return True
    except Exception as e:
        print(f"✗ 数据库连接失败: {str(e)}")
        return False


def check_tables():
    """检查数据库表"""
    print("\n检查数据库表...")
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()
        
        expected_tables = [
            "users",
            "optimization_sessions",
            "optimization_segments",
            "session_history",
            "change_logs",
            "queue_status",
            "system_settings",
            "custom_prompts"
        ]
        
        missing_tables = [t for t in expected_tables if t not in tables]
        
        if missing_tables:
            print(f"⚠ 缺少以下表: {', '.join(missing_tables)}")
            return False
        else:
            print(f"✓ 所有必需的表都存在 ({len(expected_tables)} 个)")
            return True
    except Exception as e:
        print(f"✗ 检查表失败: {str(e)}")
        return False


def display_table_info():
    """显示表信息"""
    print("\n数据库表信息:")
    print("-" * 60)
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()
        
        for table_name in sorted(tables):
            columns = inspector.get_columns(table_name)
            print(f"\n📊 {table_name} ({len(columns)} 列)")
            for col in columns[:5]:  # 只显示前5列
                col_type = str(col['type'])
                nullable = "NULL" if col['nullable'] else "NOT NULL"
                print(f"   - {col['name']}: {col_type} {nullable}")
            if len(columns) > 5:
                print(f"   ... 还有 {len(columns) - 5} 列")
    except Exception as e:
        print(f"✗ 获取表信息失败: {str(e)}")


def check_data_integrity():
    """检查数据完整性"""
    print("\n检查数据完整性...")
    try:
        db = SessionLocal()
        try:
            # 检查用户数量
            user_count = db.query(User).count()
            print(f"✓ 用户数量: {user_count}")
            
            # 检查会话数量
            session_count = db.query(OptimizationSession).count()
            print(f"✓ 会话数量: {session_count}")
            
            # 检查系统提示词
            system_prompts = db.query(CustomPrompt).filter(CustomPrompt.is_system == True).count()
            print(f"✓ 系统提示词数量: {system_prompts}")
            
            return True
        finally:
            db.close()
    except Exception as e:
        print(f"✗ 数据完整性检查失败: {str(e)}")
        return False


def test_crud_operations():
    """测试基本的 CRUD 操作"""
    print("\n测试数据库操作...")
    try:
        db = SessionLocal()
        try:
            # 测试创建
            test_setting = SystemSetting(
                key="test_key_delete_me",
                value="test_value"
            )
            db.add(test_setting)
            db.commit()
            print("✓ CREATE 操作成功")
            
            # 测试读取
            setting = db.query(SystemSetting).filter(
                SystemSetting.key == "test_key_delete_me"
            ).first()
            if setting:
                print("✓ READ 操作成功")
            
            # 测试更新
            setting.value = "updated_value"
            db.commit()
            print("✓ UPDATE 操作成功")
            
            # 测试删除
            db.delete(setting)
            db.commit()
            print("✓ DELETE 操作成功")
            
            return True
        finally:
            db.close()
    except Exception as e:
        print(f"✗ CRUD 操作测试失败: {str(e)}")
        return False


def main():
    """主函数"""
    print("=" * 60)
    print("数据库初始化和健康检查")
    print("=" * 60)
    
    # 检查环境变量
    env_file = backend_dir / ".env"
    if not env_file.exists():
        print(f"\n⚠ 警告: 未找到 .env 文件")
        print(f"   预期位置: {env_file}")
        print("   将使用默认配置\n")
    
    # 1. 检查数据库连接
    if not check_database_connection():
        print("\n❌ 数据库连接失败,无法继续")
        sys.exit(1)
    
    # 2. 初始化数据库
    print("\n" + "=" * 60)
    print("初始化数据库...")
    print("=" * 60)
    try:
        init_db()
    except Exception as e:
        print(f"\n❌ 数据库初始化失败: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    
    # 3. 检查表
    if not check_tables():
        print("\n⚠ 警告: 某些表缺失")
    
    # 4. 显示表信息
    display_table_info()
    
    # 5. 检查数据完整性
    check_data_integrity()
    
    # 6. 测试 CRUD 操作
    test_crud_operations()
    
    # 总结
    print("\n" + "=" * 60)
    print("✓ 数据库检查完成!")
    print("=" * 60)
    print("\n数据库已就绪,可以启动应用")
    
    # 显示数据库文件位置
    from app.config import settings
    if "sqlite" in settings.DATABASE_URL:
        db_path = settings.DATABASE_URL.replace("sqlite:///", "")
        if db_path.startswith("./"):
            db_path = backend_dir / db_path[2:]
        else:
            db_path = Path(db_path)
        
        if db_path.exists():
            size_mb = db_path.stat().st_size / (1024 * 1024)
            print(f"\n📁 数据库文件: {db_path}")
            print(f"   大小: {size_mb:.2f} MB")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⚠ 用户中断")
        sys.exit(0)
    except Exception as e:
        print(f"\n❌ 发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)