File size: 10,245 Bytes
35e7795 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 | #!/usr/bin/env python
"""数据库迁移脚本 - 为用户表添加单位、团队、物种字段
确保用户表包含以下新字段:
1. organization (VARCHAR) - 单位(崖州湾实验室、之江实验室)
2. team (VARCHAR) - 团队
3. species (VARCHAR) - 物种
这个脚本可以安全地多次运行(幂等性)。
主要变更:
1. 为 users 表添加 organization 列(如果不存在)
2. 为 users 表添加 team 列(如果不存在)
3. 为 users 表添加 species 列(如果不存在)
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# 注意:以下导入必须在 sys.path.insert 之后,因为需要导入项目模块
import logging # noqa: E402
import shutil # noqa: E402
from sqlalchemy import inspect, text # noqa: E402
from sqlalchemy.engine import Engine # noqa: E402
from qa_annotate.database.base import DB_PATH, engine # noqa: E402
# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
def backup_database():
"""备份数据库(SQLite)"""
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = DB_PATH.parent / f"annotations_backup_{timestamp}.db"
shutil.copy2(DB_PATH, backup_path)
logger.info(f"SQLite 数据库已备份到: {backup_path}")
return backup_path
def disable_foreign_keys(conn):
"""禁用外键约束"""
conn.execute(text("PRAGMA foreign_keys = OFF"))
logger.info("已禁用外键约束")
def enable_foreign_keys(conn):
"""启用外键约束"""
conn.execute(text("PRAGMA foreign_keys = ON"))
logger.info("已启用外键约束")
def restore_from_backup(backup_path: Path):
"""从备份恢复数据库(SQLite)"""
logger.warning("=" * 60)
logger.warning("开始从备份恢复数据库...")
logger.warning(f"备份路径: {backup_path}")
logger.warning("=" * 60)
if not backup_path.exists():
logger.error(f"备份文件不存在: {backup_path}")
return False
try:
shutil.copy2(backup_path, DB_PATH)
logger.info("SQLite 数据库已从备份恢复")
return True
except Exception as e:
logger.error(f"恢复数据库失败: {e}")
return False
def table_exists(inspector: inspect, table_name: str) -> bool:
"""检查表是否存在"""
return table_name in inspector.get_table_names()
def column_exists(inspector: inspect, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
if not table_exists(inspector, table_name):
return False
columns = inspector.get_columns(table_name)
return any(col["name"] == column_name for col in columns)
def add_user_fields(engine: Engine):
"""为用户表添加新字段(organization、team、species)"""
logger.info("检查 users 表的新字段...")
inspector = inspect(engine)
if not table_exists(inspector, "users"):
logger.error("users 表不存在,无法添加字段")
return False
changes_made = False
with engine.begin() as conn:
disable_foreign_keys(conn)
# 添加 organization 列
if not column_exists(inspector, "users", "organization"):
logger.info("添加 organization 列...")
try:
conn.execute(
text("""
ALTER TABLE users
ADD COLUMN organization VARCHAR
""")
)
logger.info("organization 列已添加")
changes_made = True
except Exception as e:
logger.warning(f"无法通过 ALTER TABLE 添加 organization 列: {e}")
logger.error("添加 organization 列失败")
else:
logger.info("organization 列已存在,跳过")
# 添加 team 列
if not column_exists(inspector, "users", "team"):
logger.info("添加 team 列...")
try:
conn.execute(
text("""
ALTER TABLE users
ADD COLUMN team VARCHAR
""")
)
logger.info("team 列已添加")
changes_made = True
except Exception as e:
logger.warning(f"无法通过 ALTER TABLE 添加 team 列: {e}")
logger.error("添加 team 列失败")
else:
logger.info("team 列已存在,跳过")
# 添加 species 列
if not column_exists(inspector, "users", "species"):
logger.info("添加 species 列...")
try:
conn.execute(
text("""
ALTER TABLE users
ADD COLUMN species VARCHAR
""")
)
logger.info("species 列已添加")
changes_made = True
except Exception as e:
logger.warning(f"无法通过 ALTER TABLE 添加 species 列: {e}")
logger.error("添加 species 列失败")
else:
logger.info("species 列已存在,跳过")
enable_foreign_keys(conn)
return changes_made
def validate_migration(engine: Engine):
"""验证迁移结果"""
logger.info("验证迁移结果...")
inspector = inspect(engine)
errors = []
# 检查 users 表是否存在
if not table_exists(inspector, "users"):
errors.append("users 表不存在")
return False, errors
logger.info("✓ users 表存在")
# 检查新字段
required_columns = ["organization", "team", "species"]
columns = {col["name"] for col in inspector.get_columns("users")}
for col in required_columns:
if col not in columns:
errors.append(f"users 表缺少列: {col}")
else:
logger.info(f" ✓ users 表有 {col} 列")
if errors:
logger.error("验证失败:")
for error in errors:
logger.error(f" - {error}")
return False, errors
logger.info("验证通过")
return True, []
def main():
"""主函数"""
logger.info("=" * 60)
logger.info("开始数据库迁移 - 为用户表添加单位、团队、物种字段")
logger.info(f"数据库路径: {DB_PATH}")
logger.info("=" * 60)
# 检查数据库文件是否存在
if not DB_PATH.exists():
logger.error(f"数据库文件不存在: {DB_PATH}")
logger.info("如果这是新安装,请先运行应用以初始化数据库")
sys.exit(1)
# 备份数据库
logger.info("备份数据库...")
backup_path = backup_database()
logger.info(f"备份完成: {backup_path}")
migration_success = False
changes_made = False
try:
# 为用户表添加新字段
if add_user_fields(engine):
changes_made = True
if not changes_made:
logger.info("=" * 60)
logger.info("所有字段已存在,无需迁移")
logger.info("=" * 60)
return
# 验证迁移
logger.info("=" * 60)
logger.info("开始验证迁移结果...")
logger.info("=" * 60)
is_valid, errors = validate_migration(engine)
if not is_valid:
logger.error("=" * 60)
logger.error("迁移验证失败!")
logger.error("=" * 60)
logger.error("错误详情:")
for error in errors:
logger.error(f" - {error}")
logger.error("=" * 60)
logger.error("开始回退到备份...")
logger.error("=" * 60)
if restore_from_backup(backup_path):
logger.error("已成功回退到备份")
else:
logger.error("回退失败,请手动恢复数据库")
sys.exit(1)
migration_success = True
logger.info("=" * 60)
logger.info("数据库迁移完成!")
logger.info("验证通过!")
logger.info("=" * 60)
# 用户确认
logger.info("=" * 60)
logger.info("迁移已完成,请确认是否接受此次迁移")
logger.info(f"备份文件位置: {backup_path}")
logger.info("=" * 60)
try:
user_input = input("确认接受迁移?(y/n): ").strip().lower()
if user_input in ("y", "yes"):
logger.info("=" * 60)
logger.info("用户已确认,迁移完成!")
logger.info("=" * 60)
else:
logger.warning("=" * 60)
logger.warning("用户已取消,开始回退到备份...")
logger.warning("=" * 60)
if restore_from_backup(backup_path):
logger.warning("已成功回退到备份")
logger.warning("迁移已取消")
else:
logger.error("回退失败,请手动恢复数据库")
sys.exit(0)
except (KeyboardInterrupt, EOFError):
logger.warning("")
logger.warning("=" * 60)
logger.warning("用户中断操作,开始回退到备份...")
logger.warning("=" * 60)
if restore_from_backup(backup_path):
logger.warning("已成功回退到备份")
logger.warning("迁移已取消")
else:
logger.error("回退失败,请手动恢复数据库")
sys.exit(0)
except Exception as e:
logger.error("=" * 60)
logger.error(f"迁移过程中发生异常: {e}", exc_info=True)
logger.error("=" * 60)
if not migration_success:
logger.error("开始回退到备份...")
if restore_from_backup(backup_path):
logger.error("已成功回退到备份")
else:
logger.error("回退失败,请手动恢复数据库")
sys.exit(1)
if __name__ == "__main__":
main()
|