Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 数据库初始化和健康检查脚本 | |
| 可以独立运行以测试数据库连接和初始化 | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # 添加项目路径到 Python 路径 | |
| backend_dir = Path(__file__).parent | |
| sys.path.insert(0, str(backend_dir)) | |
| from app.database import init_db, engine, SessionLocal | |
| from app.models.models import User, OptimizationSession, CustomPrompt, SystemSetting | |
| from sqlalchemy import text, inspect | |
| def check_database_connection(): | |
| """检查数据库连接""" | |
| print("检查数据库连接...") | |
| try: | |
| with engine.connect() as conn: | |
| result = conn.execute(text("SELECT 1")) | |
| result.fetchone() | |
| print("✓ 数据库连接成功") | |
| return True | |
| except Exception as e: | |
| print(f"✗ 数据库连接失败: {str(e)}") | |
| return False | |
| def check_tables(): | |
| """检查数据库表""" | |
| print("\n检查数据库表...") | |
| try: | |
| inspector = inspect(engine) | |
| tables = inspector.get_table_names() | |
| expected_tables = [ | |
| "users", | |
| "optimization_sessions", | |
| "optimization_segments", | |
| "session_history", | |
| "change_logs", | |
| "queue_status", | |
| "system_settings", | |
| "custom_prompts" | |
| ] | |
| missing_tables = [t for t in expected_tables if t not in tables] | |
| if missing_tables: | |
| print(f"⚠ 缺少以下表: {', '.join(missing_tables)}") | |
| return False | |
| else: | |
| print(f"✓ 所有必需的表都存在 ({len(expected_tables)} 个)") | |
| return True | |
| except Exception as e: | |
| print(f"✗ 检查表失败: {str(e)}") | |
| return False | |
| def display_table_info(): | |
| """显示表信息""" | |
| print("\n数据库表信息:") | |
| print("-" * 60) | |
| try: | |
| inspector = inspect(engine) | |
| tables = inspector.get_table_names() | |
| for table_name in sorted(tables): | |
| columns = inspector.get_columns(table_name) | |
| print(f"\n📊 {table_name} ({len(columns)} 列)") | |
| for col in columns[:5]: # 只显示前5列 | |
| col_type = str(col['type']) | |
| nullable = "NULL" if col['nullable'] else "NOT NULL" | |
| print(f" - {col['name']}: {col_type} {nullable}") | |
| if len(columns) > 5: | |
| print(f" ... 还有 {len(columns) - 5} 列") | |
| except Exception as e: | |
| print(f"✗ 获取表信息失败: {str(e)}") | |
| def check_data_integrity(): | |
| """检查数据完整性""" | |
| print("\n检查数据完整性...") | |
| try: | |
| db = SessionLocal() | |
| try: | |
| # 检查用户数量 | |
| user_count = db.query(User).count() | |
| print(f"✓ 用户数量: {user_count}") | |
| # 检查会话数量 | |
| session_count = db.query(OptimizationSession).count() | |
| print(f"✓ 会话数量: {session_count}") | |
| # 检查系统提示词 | |
| system_prompts = db.query(CustomPrompt).filter(CustomPrompt.is_system == True).count() | |
| print(f"✓ 系统提示词数量: {system_prompts}") | |
| return True | |
| finally: | |
| db.close() | |
| except Exception as e: | |
| print(f"✗ 数据完整性检查失败: {str(e)}") | |
| return False | |
| def test_crud_operations(): | |
| """测试基本的 CRUD 操作""" | |
| print("\n测试数据库操作...") | |
| try: | |
| db = SessionLocal() | |
| try: | |
| # 测试创建 | |
| test_setting = SystemSetting( | |
| key="test_key_delete_me", | |
| value="test_value" | |
| ) | |
| db.add(test_setting) | |
| db.commit() | |
| print("✓ CREATE 操作成功") | |
| # 测试读取 | |
| setting = db.query(SystemSetting).filter( | |
| SystemSetting.key == "test_key_delete_me" | |
| ).first() | |
| if setting: | |
| print("✓ READ 操作成功") | |
| # 测试更新 | |
| setting.value = "updated_value" | |
| db.commit() | |
| print("✓ UPDATE 操作成功") | |
| # 测试删除 | |
| db.delete(setting) | |
| db.commit() | |
| print("✓ DELETE 操作成功") | |
| return True | |
| finally: | |
| db.close() | |
| except Exception as e: | |
| print(f"✗ CRUD 操作测试失败: {str(e)}") | |
| return False | |
| def main(): | |
| """主函数""" | |
| print("=" * 60) | |
| print("数据库初始化和健康检查") | |
| print("=" * 60) | |
| # 检查环境变量 | |
| env_file = backend_dir / ".env" | |
| if not env_file.exists(): | |
| print(f"\n⚠ 警告: 未找到 .env 文件") | |
| print(f" 预期位置: {env_file}") | |
| print(" 将使用默认配置\n") | |
| # 1. 检查数据库连接 | |
| if not check_database_connection(): | |
| print("\n❌ 数据库连接失败,无法继续") | |
| sys.exit(1) | |
| # 2. 初始化数据库 | |
| print("\n" + "=" * 60) | |
| print("初始化数据库...") | |
| print("=" * 60) | |
| try: | |
| init_db() | |
| except Exception as e: | |
| print(f"\n❌ 数据库初始化失败: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| sys.exit(1) | |
| # 3. 检查表 | |
| if not check_tables(): | |
| print("\n⚠ 警告: 某些表缺失") | |
| # 4. 显示表信息 | |
| display_table_info() | |
| # 5. 检查数据完整性 | |
| check_data_integrity() | |
| # 6. 测试 CRUD 操作 | |
| test_crud_operations() | |
| # 总结 | |
| print("\n" + "=" * 60) | |
| print("✓ 数据库检查完成!") | |
| print("=" * 60) | |
| print("\n数据库已就绪,可以启动应用") | |
| # 显示数据库文件位置 | |
| from app.config import settings | |
| if "sqlite" in settings.DATABASE_URL: | |
| db_path = settings.DATABASE_URL.replace("sqlite:///", "") | |
| if db_path.startswith("./"): | |
| db_path = backend_dir / db_path[2:] | |
| else: | |
| db_path = Path(db_path) | |
| if db_path.exists(): | |
| size_mb = db_path.stat().st_size / (1024 * 1024) | |
| print(f"\n📁 数据库文件: {db_path}") | |
| print(f" 大小: {size_mb:.2f} MB") | |
| if __name__ == "__main__": | |
| try: | |
| main() | |
| except KeyboardInterrupt: | |
| print("\n\n⚠ 用户中断") | |
| sys.exit(0) | |
| except Exception as e: | |
| print(f"\n❌ 发生错误: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| sys.exit(1) | |