File size: 3,603 Bytes
9c50399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Tiny forward/down migration runner for SQLite.



Each ``.sql`` file in the project's ``migrations/`` directory has a ``-- up`` section followed by a ``-- down`` section.

Applied migrations are tracked in ``schema_migrations``.



Usage:

    python -m scripts.migrate up        # apply all pending

    python -m scripts.migrate down       # revert the most recent

    python -m scripts.migrate down 0     # revert everything (to a clean DB)

    python -m scripts.migrate status

"""

import os
import re
import sys

from app import db

# Absolute path of the migrations directory: a "migrations" folder that sits
# next to the directory containing this file (i.e. two dirname() hops up from
# this file's absolute path).
MIGRATIONS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "migrations",
)


def _parse(path: str) -> tuple[str, str]:
    """Split a migration file into its ``(up_sql, down_sql)`` halves.

    The file must contain a ``-- up`` marker followed later by a ``-- down``
    marker. Markers are matched case-insensitively and the first occurrence
    of each is used, so an earlier comment such as ``-- update`` would be
    taken as the up marker. The text between the markers is the up SQL and
    everything after ``-- down`` is the down SQL. Both are whitespace-stripped.

    Raises:
        ValueError: if either marker is missing, or ``-- down`` comes before
            ``-- up``.
    """
    with open(path, encoding="utf-8") as f:
        text = f.read()
    # Match on the original text instead of taking indices from text.lower().
    # lower() can change string length (e.g. "İ" becomes two code points), which
    # would misalign the indices with `text` and corrupt the sliced SQL.
    up_match = re.search("-- up", text, re.IGNORECASE)
    down_match = re.search("-- down", text, re.IGNORECASE)
    if up_match is None or down_match is None:
        raise ValueError(f"{path} must contain '-- up' and '-- down' markers")
    if down_match.start() < up_match.start():
        # Otherwise the up slice would be empty: the migration would be recorded
        # as applied without running anything.
        raise ValueError(f"{path}: the '-- down' marker must come after '-- up'")
    up_sql = text[up_match.end():down_match.start()].strip()
    down_sql = text[down_match.end():].strip()
    return up_sql, down_sql


def _all_migrations() -> list[str]:
    """Return the names of all ``.sql`` files in MIGRATIONS_DIR, sorted lexically."""
    names = [entry for entry in os.listdir(MIGRATIONS_DIR) if entry.endswith(".sql")]
    names.sort()
    return names


def _ensure_table(conn) -> None:
    """Create the ``schema_migrations`` bookkeeping table if it is not there yet.

    Each row holds a migration filename (primary key) and an ``applied_at``
    timestamp that SQLite fills in as ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    ddl = (
        "CREATE TABLE IF NOT EXISTS schema_migrations ("
        " name TEXT PRIMARY KEY,"
        " applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')));"
    )
    conn.executescript(ddl)


def _applied(conn) -> set[str]:
    """Return the filenames recorded in ``schema_migrations``.

    Ensures the tracking table exists first, so it also works on a fresh database.
    """
    _ensure_table(conn)
    cursor = conn.execute("SELECT name FROM schema_migrations")
    rows = cursor.fetchall()
    return {row[0] for row in rows}


def up() -> list[str]:
    """Apply every migration file not yet recorded, in filename order.

    For each pending file, its ``-- up`` SQL is executed, its name is inserted
    into ``schema_migrations``, and the connection is committed before moving on.

    Returns:
        Filenames applied by this call, in the order applied (empty if none
        were pending).
    """
    conn = db.connect()
    try:
        already = _applied(conn)
        pending = [migration for migration in _all_migrations() if migration not in already]
        applied_now: list[str] = []
        for migration in pending:
            script, _unused_down = _parse(os.path.join(MIGRATIONS_DIR, migration))
            # NOTE(review): per the sqlite3 docs, executescript() commits any
            # pending transaction first and adds no transaction control of its
            # own. A script that fails partway may leave earlier statements
            # applied while the migration stays unrecorded. Confirm migrations
            # tolerate this, or add BEGIN/COMMIT inside them.
            conn.executescript(script)
            conn.execute("INSERT INTO schema_migrations (name) VALUES (?)", (migration,))
            conn.commit()
            applied_now.append(migration)
        return applied_now
    finally:
        conn.close()


def down(target: int | None = None) -> list[str]:
    """Revert migrations. ``target`` is how many to keep applied (default: keep all-but-last)."""
    reverted = []
    conn = db.connect()
    try:
        applied = sorted(_applied(conn))
        if target is None:
            to_revert = applied[-1:] if applied else []
        else:
            to_revert = applied[target:]
        for name in reversed(to_revert):
            _, down_sql = _parse(os.path.join(MIGRATIONS_DIR, name))
            conn.executescript(down_sql)
            conn.execute("DELETE FROM schema_migrations WHERE name = ?", (name,))
            conn.commit()
            reverted.append(name)
        return reverted
    finally:
        conn.close()


def status() -> None:
    """Print a checklist of migration files: ``[x]`` if applied, ``[ ]`` if pending."""
    conn = db.connect()
    try:
        done = _applied(conn)
        for migration in _all_migrations():
            mark = "x" if migration in done else " "
            print(f"[{mark}] {migration}")
    finally:
        conn.close()


def main(argv: list[str]) -> int:
    cmd = argv[1] if len(argv) > 1 else "up"
    if cmd == "up":
        print("applied:", up() or "(nothing pending)")
    elif cmd == "down":
        tgt = int(argv[2]) if len(argv) > 2 else None
        print("reverted:", down(tgt) or "(nothing to revert)")
    elif cmd == "status":
        status()
    else:
        print(__doc__)
        return 2
    return 0


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))