File size: 10,245 Bytes
35e7795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
#!/usr/bin/env python
"""数据库迁移脚本 - 为用户表添加单位、团队、物种字段

确保用户表包含以下新字段:
1. organization (VARCHAR) - 单位(崖州湾实验室、之江实验室)
2. team (VARCHAR) - 团队
3. species (VARCHAR) - 物种

这个脚本可以安全地多次运行(幂等性)。

主要变更:
1. 为 users 表添加 organization 列(如果不存在)
2. 为 users 表添加 team 列(如果不存在)
3. 为 users 表添加 species 列(如果不存在)
"""

import sys
from pathlib import Path

# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# 注意:以下导入必须在 sys.path.insert 之后,因为需要导入项目模块
import logging  # noqa: E402
import shutil  # noqa: E402

from sqlalchemy import inspect, text  # noqa: E402
from sqlalchemy.engine import Engine  # noqa: E402

from qa_annotate.database.base import DB_PATH, engine  # noqa: E402

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def backup_database():
    """备份数据库(SQLite)"""
    from datetime import datetime

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = DB_PATH.parent / f"annotations_backup_{timestamp}.db"
    shutil.copy2(DB_PATH, backup_path)
    logger.info(f"SQLite 数据库已备份到: {backup_path}")
    return backup_path


def disable_foreign_keys(conn):
    """禁用外键约束"""
    conn.execute(text("PRAGMA foreign_keys = OFF"))
    logger.info("已禁用外键约束")


def enable_foreign_keys(conn):
    """启用外键约束"""
    conn.execute(text("PRAGMA foreign_keys = ON"))
    logger.info("已启用外键约束")


def restore_from_backup(backup_path: Path):
    """从备份恢复数据库(SQLite)"""
    logger.warning("=" * 60)
    logger.warning("开始从备份恢复数据库...")
    logger.warning(f"备份路径: {backup_path}")
    logger.warning("=" * 60)

    if not backup_path.exists():
        logger.error(f"备份文件不存在: {backup_path}")
        return False

    try:
        shutil.copy2(backup_path, DB_PATH)
        logger.info("SQLite 数据库已从备份恢复")
        return True
    except Exception as e:
        logger.error(f"恢复数据库失败: {e}")
        return False


def table_exists(inspector: inspect, table_name: str) -> bool:
    """检查表是否存在"""
    return table_name in inspector.get_table_names()


def column_exists(inspector: inspect, table_name: str, column_name: str) -> bool:
    """检查列是否存在"""
    if not table_exists(inspector, table_name):
        return False
    columns = inspector.get_columns(table_name)
    return any(col["name"] == column_name for col in columns)


def add_user_fields(engine: Engine):
    """为用户表添加新字段(organization、team、species)"""
    logger.info("检查 users 表的新字段...")

    inspector = inspect(engine)

    if not table_exists(inspector, "users"):
        logger.error("users 表不存在,无法添加字段")
        return False

    changes_made = False

    with engine.begin() as conn:
        disable_foreign_keys(conn)

        # 添加 organization 列
        if not column_exists(inspector, "users", "organization"):
            logger.info("添加 organization 列...")
            try:
                conn.execute(
                    text("""
                    ALTER TABLE users
                    ADD COLUMN organization VARCHAR
                """)
                )
                logger.info("organization 列已添加")
                changes_made = True
            except Exception as e:
                logger.warning(f"无法通过 ALTER TABLE 添加 organization 列: {e}")
                logger.error("添加 organization 列失败")
        else:
            logger.info("organization 列已存在,跳过")

        # 添加 team 列
        if not column_exists(inspector, "users", "team"):
            logger.info("添加 team 列...")
            try:
                conn.execute(
                    text("""
                    ALTER TABLE users
                    ADD COLUMN team VARCHAR
                """)
                )
                logger.info("team 列已添加")
                changes_made = True
            except Exception as e:
                logger.warning(f"无法通过 ALTER TABLE 添加 team 列: {e}")
                logger.error("添加 team 列失败")
        else:
            logger.info("team 列已存在,跳过")

        # 添加 species 列
        if not column_exists(inspector, "users", "species"):
            logger.info("添加 species 列...")
            try:
                conn.execute(
                    text("""
                    ALTER TABLE users
                    ADD COLUMN species VARCHAR
                """)
                )
                logger.info("species 列已添加")
                changes_made = True
            except Exception as e:
                logger.warning(f"无法通过 ALTER TABLE 添加 species 列: {e}")
                logger.error("添加 species 列失败")
        else:
            logger.info("species 列已存在,跳过")

        enable_foreign_keys(conn)

    return changes_made


def validate_migration(engine: Engine):
    """验证迁移结果"""
    logger.info("验证迁移结果...")

    inspector = inspect(engine)
    errors = []

    # 检查 users 表是否存在
    if not table_exists(inspector, "users"):
        errors.append("users 表不存在")
        return False, errors

    logger.info("✓ users 表存在")

    # 检查新字段
    required_columns = ["organization", "team", "species"]
    columns = {col["name"] for col in inspector.get_columns("users")}

    for col in required_columns:
        if col not in columns:
            errors.append(f"users 表缺少列: {col}")
        else:
            logger.info(f"  ✓ users 表有 {col} 列")

    if errors:
        logger.error("验证失败:")
        for error in errors:
            logger.error(f"  - {error}")
        return False, errors

    logger.info("验证通过")
    return True, []


def main():
    """主函数"""
    logger.info("=" * 60)
    logger.info("开始数据库迁移 - 为用户表添加单位、团队、物种字段")
    logger.info(f"数据库路径: {DB_PATH}")
    logger.info("=" * 60)

    # 检查数据库文件是否存在
    if not DB_PATH.exists():
        logger.error(f"数据库文件不存在: {DB_PATH}")
        logger.info("如果这是新安装,请先运行应用以初始化数据库")
        sys.exit(1)

    # 备份数据库
    logger.info("备份数据库...")
    backup_path = backup_database()
    logger.info(f"备份完成: {backup_path}")

    migration_success = False
    changes_made = False

    try:
        # 为用户表添加新字段
        if add_user_fields(engine):
            changes_made = True

        if not changes_made:
            logger.info("=" * 60)
            logger.info("所有字段已存在,无需迁移")
            logger.info("=" * 60)
            return

        # 验证迁移
        logger.info("=" * 60)
        logger.info("开始验证迁移结果...")
        logger.info("=" * 60)

        is_valid, errors = validate_migration(engine)

        if not is_valid:
            logger.error("=" * 60)
            logger.error("迁移验证失败!")
            logger.error("=" * 60)
            logger.error("错误详情:")
            for error in errors:
                logger.error(f"  - {error}")
            logger.error("=" * 60)
            logger.error("开始回退到备份...")
            logger.error("=" * 60)

            if restore_from_backup(backup_path):
                logger.error("已成功回退到备份")
            else:
                logger.error("回退失败,请手动恢复数据库")
            sys.exit(1)

        migration_success = True

        logger.info("=" * 60)
        logger.info("数据库迁移完成!")
        logger.info("验证通过!")
        logger.info("=" * 60)

        # 用户确认
        logger.info("=" * 60)
        logger.info("迁移已完成,请确认是否接受此次迁移")
        logger.info(f"备份文件位置: {backup_path}")
        logger.info("=" * 60)

        try:
            user_input = input("确认接受迁移?(y/n): ").strip().lower()
            if user_input in ("y", "yes"):
                logger.info("=" * 60)
                logger.info("用户已确认,迁移完成!")
                logger.info("=" * 60)
            else:
                logger.warning("=" * 60)
                logger.warning("用户已取消,开始回退到备份...")
                logger.warning("=" * 60)

                if restore_from_backup(backup_path):
                    logger.warning("已成功回退到备份")
                    logger.warning("迁移已取消")
                else:
                    logger.error("回退失败,请手动恢复数据库")
                sys.exit(0)

        except (KeyboardInterrupt, EOFError):
            logger.warning("")
            logger.warning("=" * 60)
            logger.warning("用户中断操作,开始回退到备份...")
            logger.warning("=" * 60)

            if restore_from_backup(backup_path):
                logger.warning("已成功回退到备份")
                logger.warning("迁移已取消")
            else:
                logger.error("回退失败,请手动恢复数据库")
            sys.exit(0)

    except Exception as e:
        logger.error("=" * 60)
        logger.error(f"迁移过程中发生异常: {e}", exc_info=True)
        logger.error("=" * 60)

        if not migration_success:
            logger.error("开始回退到备份...")
            if restore_from_backup(backup_path):
                logger.error("已成功回退到备份")
            else:
                logger.error("回退失败,请手动恢复数据库")
        sys.exit(1)


if __name__ == "__main__":
    main()