cacode's picture
Upload 74 files
7c15d35 verified
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.config import settings
# check_same_thread=False lets SQLite connections be used from threads other
# than the one that created them. Only the sqlite3 driver understands this
# argument, so it is passed only when the URL scheme is SQLite. The scheme
# prefix is matched rather than a substring anywhere in the URL, so e.g. a
# PostgreSQL database named "sqlite_backup" is not misdetected.
engine = create_engine(
    settings.DATABASE_URL,
    connect_args=(
        {"check_same_thread": False}
        if settings.DATABASE_URL.startswith("sqlite")
        else {}
    ),
)

# Session factory: no autocommit, and no automatic flush before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class that ORM models derive from.
Base = declarative_base()
def get_db():
    """Database-session dependency.

    Yields a fresh session from ``SessionLocal``. The session is closed once
    the consumer is done with it, even if an exception was raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def init_db():
    """Initialize the database schema safely.

    Creates missing tables, adds columns that newer code expects, and
    creates performance indexes.

    Returns:
        True when every step completed.

    Raises:
        Any exception from the steps above, after printing a failure message.
    """
    try:
        # Importing the models module registers its classes on Base.metadata,
        # so create_all knows which tables to create.
        from app.models import models  # noqa: F401

        # create_all only creates tables that are not present yet.
        Base.metadata.create_all(bind=engine)

        # Bring existing tables up to date: missing columns first, then indexes.
        for post_step in (_migrate_database_schema, _add_performance_indexes):
            post_step()

        print("✓ 数据库初始化成功")
        return True
    except Exception as exc:
        print(f"✗ 数据库初始化失败: {exc!s}")
        raise
def _add_column_safely(conn, table_name, column_name, column_def):
    """Try to add a single column to an existing table.

    Args:
        conn: Open SQLAlchemy connection. It is committed on success and
            rolled back on failure.
        table_name: Table to alter. It is interpolated into the SQL, so pass
            internal constants only.
        column_name: Name of the new column. Also interpolated, so internal
            constants only.
        column_def: Column type plus optional clauses, e.g. "INTEGER DEFAULT 0".

    Returns:
        True if the column was added. False if the statement failed for any
        reason, e.g. the column already exists. Errors are deliberately
        swallowed so that callers can treat this as best effort.
    """
    try:
        conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_def}"))
        conn.commit()
    except Exception:
        # The column may already exist, or the DDL was rejected. Roll back so
        # the connection stays usable for the next statement.
        conn.rollback()
        return False
    return True
def _add_performance_indexes():
    """Create single-column indexes used by common queries.

    This is best effort. An index is skipped when its table does not exist,
    when an index with the same name exists, or when another index already
    covers exactly the same column (a second identical index would only cost
    write time and storage). Failures print a warning and never abort
    application startup.
    """
    try:
        inspector = inspect(engine)
        tables = set(inspector.get_table_names())

        # (index name, table, column) triples to ensure.
        indexes = [
            # OptimizationSession indexes
            ("idx_opt_session_user_id", "optimization_sessions", "user_id"),
            ("idx_opt_session_status", "optimization_sessions", "status"),
            ("idx_opt_session_created_at", "optimization_sessions", "created_at"),
            # OptimizationSegment indexes
            ("idx_opt_segment_session_id", "optimization_segments", "session_id"),
            ("idx_opt_segment_index", "optimization_segments", "segment_index"),
            ("idx_opt_segment_status", "optimization_segments", "status"),
            # ChangeLog indexes
            ("idx_change_log_session_id", "change_logs", "session_id"),
            ("idx_change_log_segment_index", "change_logs", "segment_index"),
            ("idx_change_log_stage", "change_logs", "stage"),
        ]

        with engine.connect() as conn:
            for index_name, table_name, column_name in indexes:
                if table_name not in tables:
                    continue
                try:
                    existing_indexes = inspector.get_indexes(table_name)
                    # Skip if the name is taken, or if a differently named
                    # index (for example one declared on the model, if any)
                    # already covers exactly this column.
                    if any(
                        idx["name"] == index_name
                        or idx.get("column_names") == [column_name]
                        for idx in existing_indexes
                    ):
                        continue
                    # IF NOT EXISTS keeps the statement idempotent; both SQLite
                    # and PostgreSQL accept this syntax.
                    conn.execute(text(
                        f"CREATE INDEX IF NOT EXISTS {index_name} "
                        f"ON {table_name} ({column_name})"
                    ))
                    conn.commit()
                    print(f" ✓ 添加索引: {index_name}")
                except Exception as e:
                    # Reset the connection and continue with the remaining
                    # indexes; report instead of silently ignoring.
                    conn.rollback()
                    print(f" ⚠ 添加索引失败 {index_name}: {e}")
    except Exception as e:
        # Index creation must never prevent the application from starting.
        print(f" ⚠ 添加性能索引警告: {str(e)}")
def _backfill_user_usage_defaults(conn):
    """Replace NULL ``usage_limit`` / ``usage_count`` values on ``users``.

    ``usage_limit`` gets ``settings.DEFAULT_USAGE_LIMIT`` and ``usage_count``
    gets 0. This is best effort: on any error the transaction is rolled back
    and the error is ignored.
    """
    try:
        # Bind the configured value rather than formatting it into the SQL text.
        conn.execute(
            text("UPDATE users SET usage_limit = :limit WHERE usage_limit IS NULL"),
            {"limit": settings.DEFAULT_USAGE_LIMIT},
        )
        conn.execute(text("UPDATE users SET usage_count = 0 WHERE usage_count IS NULL"))
        conn.commit()
    except Exception:
        conn.rollback()


def _migrate_database_schema():
    """Add columns that newer code expects to tables that already exist.

    Each missing column is added independently with ``_add_column_safely``.
    For ``users``, NULL usage values are also backfilled. Any unexpected
    error prints a warning and does not stop application startup.
    """
    try:
        inspector = inspect(engine)
        tables = set(inspector.get_table_names())

        # PostgreSQL rejects integer defaults on BOOLEAN columns ("default
        # expression is of type integer"), so it gets boolean literals. Other
        # backends keep the original 0/1 form.
        if engine.dialect.name == "postgresql":
            bool_false, bool_true = "FALSE", "TRUE"
        else:
            bool_false, bool_true = "0", "1"

        # table -> [(column, column DDL)], applied in this order. Every column
        # is checked on its own, so a partially migrated table is completed.
        migrations = {
            "optimization_sessions": [
                ("failed_segment_index", "INTEGER"),
                ("processing_mode", "VARCHAR(50) DEFAULT 'paper_polish_enhance'"),
                ("emotion_model", "VARCHAR(100)"),
                ("emotion_api_key", "VARCHAR(255)"),
                ("emotion_base_url", "VARCHAR(255)"),
            ],
            "users": [
                ("usage_limit", f"INTEGER DEFAULT {settings.DEFAULT_USAGE_LIMIT}"),
                ("usage_count", "INTEGER DEFAULT 0"),
            ],
            "optimization_segments": [
                ("is_title", f"BOOLEAN DEFAULT {bool_false}"),
            ],
            "custom_prompts": [
                ("is_system", f"BOOLEAN DEFAULT {bool_false}"),
                ("is_active", f"BOOLEAN DEFAULT {bool_true}"),
            ],
        }

        with engine.connect() as conn:
            for table_name, column_specs in migrations.items():
                if table_name not in tables:
                    continue
                present = {column["name"] for column in inspector.get_columns(table_name)}
                for column_name, column_def in column_specs:
                    if column_name in present:
                        continue
                    if _add_column_safely(conn, table_name, column_name, column_def):
                        print(f" ✓ 添加字段: {table_name}.{column_name}")
                if table_name == "users":
                    _backfill_user_usage_defaults(conn)
    except Exception as e:
        # A failed migration must not prevent the application from starting.
        print(f" ⚠ 数据库迁移警告: {str(e)}")