Spaces:
Running
Running
File size: 8,442 Bytes
7c15d35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.config import settings
# Build the engine from the configured URL. The ``check_same_thread`` flag is
# passed only when the URL contains "sqlite"; every other backend gets no
# extra connect arguments.
# NOTE(review): presumably the flag is disabled so a connection may be used
# from a thread other than the one that created it — confirm against callers.
engine = create_engine(
    settings.DATABASE_URL,
    connect_args=(
        {}
        if "sqlite" not in settings.DATABASE_URL
        else {"check_same_thread": False}
    ),
)

# Session factory bound to the engine, with autocommit and autoflush both off.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class; ``init_db`` creates tables from ``Base.metadata``.
Base = declarative_base()
def get_db():
    """Yield a database session and close it when the consumer is finished.

    A fresh session is created from ``SessionLocal`` on every call. The
    ``finally`` clause closes it whether the consumer completes normally or
    raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def init_db():
    """Initialise the database: create tables, patch columns, add indexes.

    Returns:
        True when every step completed.

    Raises:
        Exception: any error raised by a step is reported on stdout and then
            re-raised unchanged.
    """
    try:
        # Imported only for its side effect; per the original author's note
        # this registers the model classes on ``Base.metadata``.
        from app.models import models  # noqa: F401

        # Create the tables known to the metadata (if they do not exist).
        Base.metadata.create_all(bind=engine)

        # Add columns that may be missing from already-existing tables.
        _migrate_database_schema()

        # Add the performance indexes.
        _add_performance_indexes()

        print("✓ 数据库初始化成功")
        return True
    except Exception as exc:
        print(f"✗ 数据库初始化失败: {exc!s}")
        raise
def _add_column_safely(conn, table_name, column_name, column_def):
    """Try to add a column to a table, treating any failure as non-fatal.

    Args:
        conn: Open connection used to execute and commit the statement.
        table_name: Name of the table to alter.
        column_name: Name of the column to add.
        column_def: Column type and optional default, appended verbatim to
            the statement.

    Returns:
        True if the ALTER TABLE statement was executed and committed, False
        if anything went wrong (the transaction is rolled back in that case).

    Note:
        The statement is assembled by string interpolation, so callers must
        pass trusted, hard-coded identifiers only.
    """
    statement = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_def}"
    try:
        conn.execute(text(statement))
        conn.commit()
    except Exception as exc:
        # Best effort: the column may already exist, or the backend may have
        # rejected the definition. Roll back first, then report the reason so
        # a failed migration step is visible instead of silently ignored.
        conn.rollback()
        print(f" ⚠ 添加字段失败: {table_name}.{column_name}: {exc}")
        return False
    return True
def _add_performance_indexes():
    """Create the single-column performance indexes that are still missing.

    Each entry is skipped when its table does not exist or when an index of
    the same name is already present on that table. A failure on one index is
    rolled back and reported, and the remaining indexes are still attempted.
    This function never raises: any other error is reported as a warning so
    that application start-up is not blocked.
    """
    # (index name, table name, indexed column)
    wanted_indexes = (
        # OptimizationSession indexes
        ("idx_opt_session_user_id", "optimization_sessions", "user_id"),
        ("idx_opt_session_status", "optimization_sessions", "status"),
        ("idx_opt_session_created_at", "optimization_sessions", "created_at"),
        # OptimizationSegment indexes
        ("idx_opt_segment_session_id", "optimization_segments", "session_id"),
        ("idx_opt_segment_index", "optimization_segments", "segment_index"),
        ("idx_opt_segment_status", "optimization_segments", "status"),
        # ChangeLog indexes
        ("idx_change_log_session_id", "change_logs", "session_id"),
        ("idx_change_log_segment_index", "change_logs", "segment_index"),
        ("idx_change_log_stage", "change_logs", "stage"),
    )

    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()

        with engine.connect() as conn:
            for index_name, table_name, column_name in wanted_indexes:
                # Nothing to index when the table has not been created.
                if table_name not in tables:
                    continue

                try:
                    existing_names = {
                        idx["name"] for idx in inspector.get_indexes(table_name)
                    }
                    if index_name in existing_names:
                        continue

                    # Per the original author's note, SQLite and PostgreSQL
                    # both accept this syntax.
                    conn.execute(text(
                        f"CREATE INDEX IF NOT EXISTS {index_name} "
                        f"ON {table_name} ({column_name})"
                    ))
                    conn.commit()
                    print(f" ✓ 添加索引: {index_name}")
                except Exception as exc:
                    # Keep going with the other indexes, but make the failure
                    # visible instead of discarding it.
                    conn.rollback()
                    print(f" ⚠ 添加索引失败: {index_name}: {exc}")
    except Exception as e:
        # A failure here must not prevent the application from starting.
        print(f" ⚠ 添加性能索引警告: {str(e)}")
def _migrate_database_schema():
    """Add columns that are missing from tables which already exist.

    For each known table the current column names are read through the
    inspector, and every expected column that is absent is added with
    ``_add_column_safely``. For the ``users`` table, NULL values in
    ``usage_limit`` / ``usage_count`` are then back-filled.

    This function never raises: any error is reported as a warning so that
    application start-up is not blocked.
    """
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()

        with engine.connect() as conn:

            def add_column(table_name, column_name, column_def):
                """Add one column and announce it on success."""
                added = _add_column_safely(conn, table_name, column_name, column_def)
                if added:
                    print(f" ✓ 添加字段: {table_name}.{column_name}")
                return added

            # --- optimization_sessions ---
            if "optimization_sessions" in tables:
                columns = {
                    column["name"]
                    for column in inspector.get_columns("optimization_sessions")
                }
                if "failed_segment_index" not in columns:
                    add_column("optimization_sessions", "failed_segment_index", "INTEGER")
                if "processing_mode" not in columns:
                    add_column(
                        "optimization_sessions",
                        "processing_mode",
                        "VARCHAR(50) DEFAULT 'paper_polish_enhance'",
                    )

                # Each emotion_* column is checked on its own, so a table
                # that already has some of them still receives the rest.
                emotion_added = False
                for column_name, column_def in (
                    ("emotion_model", "VARCHAR(100)"),
                    ("emotion_api_key", "VARCHAR(255)"),
                    ("emotion_base_url", "VARCHAR(255)"),
                ):
                    if column_name in columns:
                        continue
                    if _add_column_safely(
                        conn, "optimization_sessions", column_name, column_def
                    ):
                        emotion_added = True
                if emotion_added:
                    print(" ✓ 添加字段: optimization_sessions.emotion_* 字段")

            # --- users ---
            if "users" in tables:
                user_columns = {
                    column["name"] for column in inspector.get_columns("users")
                }
                if "usage_limit" not in user_columns:
                    add_column(
                        "users",
                        "usage_limit",
                        f"INTEGER DEFAULT {settings.DEFAULT_USAGE_LIMIT}",
                    )
                if "usage_count" not in user_columns:
                    add_column("users", "usage_count", "INTEGER DEFAULT 0")

                # Back-fill NULL values. The limit is passed as a bound
                # parameter rather than interpolated into the SQL text.
                try:
                    conn.execute(
                        text(
                            "UPDATE users SET usage_limit = :usage_limit "
                            "WHERE usage_limit IS NULL"
                        ),
                        {"usage_limit": settings.DEFAULT_USAGE_LIMIT},
                    )
                    conn.execute(text(
                        "UPDATE users SET usage_count = 0 WHERE usage_count IS NULL"
                    ))
                    conn.commit()
                except Exception:
                    conn.rollback()

            # NOTE(review): the BOOLEAN columns below use integer defaults
            # (0 / 1). PostgreSQL is believed to reject an integer default on
            # a boolean column — confirm if PostgreSQL is a supported target.

            # --- optimization_segments ---
            if "optimization_segments" in tables:
                segment_columns = {
                    column["name"]
                    for column in inspector.get_columns("optimization_segments")
                }
                if "is_title" not in segment_columns:
                    add_column("optimization_segments", "is_title", "BOOLEAN DEFAULT 0")

            # --- custom_prompts ---
            if "custom_prompts" in tables:
                prompt_columns = {
                    column["name"]
                    for column in inspector.get_columns("custom_prompts")
                }
                if "is_system" not in prompt_columns:
                    add_column("custom_prompts", "is_system", "BOOLEAN DEFAULT 0")
                if "is_active" not in prompt_columns:
                    add_column("custom_prompts", "is_active", "BOOLEAN DEFAULT 1")
    except Exception as e:
        # A failed migration must not prevent the application from starting.
        print(f" ⚠ 数据库迁移警告: {str(e)}")
|