"""Database engine/session setup and lightweight startup schema migration."""

from sqlalchemy import create_engine, inspect, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from app.config import settings

engine = create_engine(
    settings.DATABASE_URL,
    # The sqlite3 driver refuses cross-thread use of a connection unless
    # check_same_thread is disabled; other backends get no extra arguments.
    connect_args=(
        {"check_same_thread": False} if "sqlite" in settings.DATABASE_URL else {}
    ),
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


def get_db():
    """Yield a database session and always close it afterwards.

    Written as a generator so it can be used as a request-scoped dependency
    (presumably by the web framework -- confirm against the callers).
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def init_db() -> bool:
    """Create missing tables, then apply column migrations and indexes.

    Returns:
        True when initialization completed.

    Raises:
        Exception: any error from model import or table creation is printed
            and re-raised. The migration and index helpers print their own
            warnings and do not propagate.
    """
    try:
        # Import the models so they are registered on Base.metadata.
        from app.models import models  # noqa: F401

        # Create every table that does not exist yet.
        Base.metadata.create_all(bind=engine)

        # Add columns that older databases may be missing.
        _migrate_database_schema()

        # Add the performance indexes.
        _add_performance_indexes()

        print("✓ 数据库初始化成功")
        return True
    except Exception as e:
        print(f"✗ 数据库初始化失败: {str(e)}")
        raise


def _add_column_safely(conn, table_name, column_name, column_def) -> bool:
    """Add a column with ALTER TABLE, reporting (not raising) any failure.

    Args:
        conn: open SQLAlchemy connection used to run and commit the DDL.
        table_name: table to alter.
        column_name: name of the column to add.
        column_def: SQL type/default clause appended after the column name.

    Returns:
        True if the column was added and committed, False if the statement
        failed (the transaction is rolled back in that case).

    Note:
        The identifiers are interpolated into the SQL text, so they must only
        ever come from trusted, hard-coded values.
    """
    try:
        conn.execute(
            text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_def}")
        )
        conn.commit()
        return True
    except Exception as e:
        conn.rollback()
        # Callers check for the column first, so reaching this point is
        # unexpected: surface the reason instead of hiding it.
        print(f" ⚠ 添加字段失败: {table_name}.{column_name}: {e}")
        return False


def _add_performance_indexes() -> None:
    """Create the single-column performance indexes that do not exist yet.

    Failures are printed as warnings and never propagate, so a problem here
    cannot prevent application startup.
    """
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()

        # (index name, table, column) for every index wanted.
        indexes = [
            # OptimizationSession indexes
            ("idx_opt_session_user_id", "optimization_sessions", "user_id"),
            ("idx_opt_session_status", "optimization_sessions", "status"),
            ("idx_opt_session_created_at", "optimization_sessions", "created_at"),
            # OptimizationSegment indexes
            ("idx_opt_segment_session_id", "optimization_segments", "session_id"),
            ("idx_opt_segment_index", "optimization_segments", "segment_index"),
            ("idx_opt_segment_status", "optimization_segments", "status"),
            # ChangeLog indexes
            ("idx_change_log_session_id", "change_logs", "session_id"),
            ("idx_change_log_segment_index", "change_logs", "segment_index"),
            ("idx_change_log_stage", "change_logs", "stage"),
        ]

        # Existing index names per table, looked up once per table.
        existing_by_table = {}

        with engine.connect() as conn:
            for index_name, table_name, column_name in indexes:
                if table_name not in tables:
                    continue

                try:
                    if table_name not in existing_by_table:
                        existing_by_table[table_name] = {
                            idx["name"]
                            for idx in inspector.get_indexes(table_name)
                        }
                    if index_name in existing_by_table[table_name]:
                        continue

                    # IF NOT EXISTS is a second guard in case the inspector's
                    # view is out of date.
                    conn.execute(text(
                        f"CREATE INDEX IF NOT EXISTS {index_name} "
                        f"ON {table_name} ({column_name})"
                    ))
                    conn.commit()
                    print(f" ✓ 添加索引: {index_name}")
                except Exception as e:
                    conn.rollback()
                    # Report and move on to the next index; startup continues.
                    print(f" ⚠ 添加索引失败: {index_name}: {e}")
    except Exception as e:
        print(f" ⚠ 添加性能索引警告: {str(e)}")
        # A failure here must not prevent application startup.


def _add_missing_columns(conn, existing_columns, table_name, column_defs):
    """Add each (name, definition) pair not in existing_columns.

    Prints one confirmation line per column added and returns the list of
    column names that were added successfully.
    """
    added = []
    for column_name, column_def in column_defs:
        if column_name in existing_columns:
            continue
        if _add_column_safely(conn, table_name, column_name, column_def):
            print(f" ✓ 添加字段: {table_name}.{column_name}")
            added.append(column_name)
    return added


def _migrate_database_schema() -> None:
    """Add newly introduced columns to tables that already exist.

    Only tables reported by the inspector are touched. Any error is printed
    as a warning and swallowed so that migration cannot block startup.
    """
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()

        with engine.connect() as conn:
            # optimization_sessions
            if "optimization_sessions" in tables:
                columns = {
                    column["name"]
                    for column in inspector.get_columns("optimization_sessions")
                }
                _add_missing_columns(conn, columns, "optimization_sessions", [
                    ("failed_segment_index", "INTEGER"),
                    ("processing_mode", "VARCHAR(50) DEFAULT 'paper_polish_enhance'"),
                ])

                # Check each emotion_* column on its own so a partially
                # migrated table gets repaired; one summary line is printed.
                emotion_columns = (
                    ("emotion_model", "VARCHAR(100)"),
                    ("emotion_api_key", "VARCHAR(255)"),
                    ("emotion_base_url", "VARCHAR(255)"),
                )
                emotion_added = False
                for column_name, column_def in emotion_columns:
                    if column_name in columns:
                        continue
                    if _add_column_safely(
                        conn, "optimization_sessions", column_name, column_def
                    ):
                        emotion_added = True
                if emotion_added:
                    print(" ✓ 添加字段: optimization_sessions.emotion_* 字段")

            # users
            if "users" in tables:
                user_columns = {
                    column["name"] for column in inspector.get_columns("users")
                }
                _add_missing_columns(conn, user_columns, "users", [
                    ("usage_limit", f"INTEGER DEFAULT {settings.DEFAULT_USAGE_LIMIT}"),
                    ("usage_count", "INTEGER DEFAULT 0"),
                ])

                # Backfill rows whose usage columns are still NULL.
                try:
                    conn.execute(text(
                        f"UPDATE users SET usage_limit = {settings.DEFAULT_USAGE_LIMIT} "
                        "WHERE usage_limit IS NULL"
                    ))
                    conn.execute(text(
                        "UPDATE users SET usage_count = 0 WHERE usage_count IS NULL"
                    ))
                    conn.commit()
                except Exception:
                    conn.rollback()

            # optimization_segments
            if "optimization_segments" in tables:
                segment_columns = {
                    column["name"]
                    for column in inspector.get_columns("optimization_segments")
                }
                _add_missing_columns(conn, segment_columns, "optimization_segments", [
                    ("is_title", "BOOLEAN DEFAULT 0"),
                ])

            # custom_prompts
            if "custom_prompts" in tables:
                prompt_columns = {
                    column["name"]
                    for column in inspector.get_columns("custom_prompts")
                }
                _add_missing_columns(conn, prompt_columns, "custom_prompts", [
                    ("is_system", "BOOLEAN DEFAULT 0"),
                    ("is_active", "BOOLEAN DEFAULT 1"),
                ])

    except Exception as e:
        print(f" ⚠ 数据库迁移警告: {str(e)}")
        # A migration failure must not prevent application startup.