File size: 4,950 Bytes
a5018da
 
 
 
 
 
 
 
 
 
 
 
2a51199
a5018da
2a51199
a5018da
 
 
 
 
 
 
e03d746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5018da
 
 
 
 
 
 
 
2a51199
 
 
 
 
 
 
a5018da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a51199
a5018da
 
 
2a51199
a5018da
 
2a51199
 
 
 
 
 
 
 
 
a5018da
2a51199
a5018da
 
 
2a51199
a5018da
2a51199
a5018da
 
 
 
2a51199
 
 
 
 
 
a5018da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
db_sync.py — Persists the Label Studio SQLite DB across HF Space rebuilds
by syncing to/from the private TrustLLMeu/saga-db-backup dataset repo.

Usage:
  python3 db_sync.py restore   # pull DB from HF repo → DB_PATH (auto-detected; default /data/ls/label_studio.sqlite3)
  python3 db_sync.py backup    # push DB_PATH → HF repo (default command when no argument is given)
  python3 db_sync.py watch     # back up every INTERVAL seconds (run in background)
"""

import contextlib
import os
import shutil
import sqlite3
import sys
import tempfile
import time

# Hub credentials and where the snapshot lives in the backup dataset repo.
HF_TOKEN = os.getenv("HF_TOKEN", "")            # empty string when unset; _api() rejects that
BACKUP_REPO = "TrustLLMeu/saga-db-backup"       # dataset repo holding the DB snapshot
REMOTE_FILE = "label_studio.sqlite3"            # snapshot file name inside BACKUP_REPO
INTERVAL = 5 * 60                               # seconds between backups in watch mode

# Find the real LS DB — the candidate with the most rows in htx_user.
# LS may store its DB at /data/label_studio.sqlite3 (not /data/ls/label_studio.sqlite3),
# so every *.sqlite3 under the search roots is probed.
def _find_db_path(
    patterns=("/data/**/*.sqlite3", "/label-studio/**/*.sqlite3"),
    default="/data/ls/label_studio.sqlite3",
):
    """Return the path of the SQLite file that looks like the live Label Studio DB.

    Every file matching one of *patterns* (recursive globs) is opened and its
    ``htx_user`` row count read. The file with the highest count wins; on a tie
    the first candidate in sorted order wins. Files that are not SQLite
    databases, cannot be opened, or lack the ``htx_user`` table are skipped.

    Args:
        patterns: Recursive glob patterns to search for candidate DB files.
        default: Path returned when no candidate has a readable ``htx_user`` table.

    Returns:
        The chosen path, or *default* if nothing qualified.
    """
    import glob as _g

    # Sorted + de-duplicated so the choice (and tie-break) does not depend on
    # directory enumeration order.
    candidates = sorted(
        {path for pattern in patterns for path in _g.glob(pattern, recursive=True)}
    )
    best, best_n = default, -1
    for path in candidates:
        try:
            # closing(): a Connection used as a context manager only commits or
            # rolls back, so close explicitly — also when the query raises.
            with contextlib.closing(sqlite3.connect(path)) as conn:
                n = conn.execute("SELECT COUNT(*) FROM htx_user").fetchone()[0]
        except sqlite3.Error:
            # Not a SQLite DB, unreadable, or no htx_user table: not the LS DB.
            continue
        if n > best_n:
            best, best_n = path, n
    print(f"[db_sync] DB path: {best} ({best_n} users)", flush=True)
    return best

# Resolved once, at import time; restore() and backup() both read and write this path.
DB_PATH = _find_db_path()


def _api():
    """Return an ``HfApi`` client authenticated with HF_TOKEN.

    The hub client library is imported lazily, only when a client is needed.

    Raises:
        RuntimeError: if HF_TOKEN is empty.
    """
    from huggingface_hub import HfApi

    if HF_TOKEN:
        return HfApi(token=HF_TOKEN)
    raise RuntimeError("HF_TOKEN env var not set")


def _safe_copy(src, dst):
    """Snapshot the SQLite database *src* into the file *dst* via the backup API.

    ``Connection.backup`` reads through SQLite rather than copying bytes, so
    committed pages that still sit in the WAL file are included and *dst*
    is a consistent point-in-time copy.

    Both connections are closed explicitly. Using a ``sqlite3.Connection`` as
    a context manager only commits or rolls back; it does not close it.
    """
    with contextlib.closing(sqlite3.connect(src)) as src_conn:
        with contextlib.closing(sqlite3.connect(dst)) as dst_conn:
            src_conn.backup(dst_conn)


def restore():
    """Download the DB snapshot from BACKUP_REPO and install it at DB_PATH.

    The steps are:

    1. Download the snapshot into a scratch directory created inside
       DB_PATH's directory.
    2. Open it once as a sanity check.
    3. Move it over DB_PATH with ``os.replace``. This is atomic within one
       filesystem, so a failed or bad download never clobbers the current DB.

    Sidecar files (``-wal``, ``-shm``, ``-journal``) left by the previous DB
    are deleted just before the swap. They belong to the old file and must
    not be applied to the restored one.

    NOTE(review): assumes Label Studio is not using DB_PATH while this runs —
    TODO confirm with the startup script.

    Returns:
        True if a snapshot was installed at DB_PATH; False if the repo holds no
        snapshot or any step failed (the error is printed, not raised).
    """
    try:
        api = _api()
        # Check whether the backup file exists in the repo at all.
        files = api.list_repo_files(BACKUP_REPO, repo_type="dataset")
        if REMOTE_FILE not in files:
            print(f"[db_sync] No backup found in {BACKUP_REPO} — fresh start.", flush=True)
            return False

        print(f"[db_sync] Restoring DB from {BACKUP_REPO}...", flush=True)
        db_dir = os.path.dirname(DB_PATH)
        os.makedirs(db_dir, exist_ok=True)

        # Scratch dir on the same filesystem as DB_PATH, so os.replace() is atomic.
        with tempfile.TemporaryDirectory(prefix=".restore_tmp_", dir=db_dir) as tmp_dir:
            path = api.hf_hub_download(
                repo_id=BACKUP_REPO,
                filename=REMOTE_FILE,
                repo_type="dataset",
                local_dir=tmp_dir,
                # Ask for a real file (not a cache symlink) so os.replace moves data.
                local_dir_use_symlinks=False,
            )
            # Sanity check: raises — leaving DB_PATH untouched — if the
            # download is not a SQLite DB with the Label Studio user table.
            with contextlib.closing(sqlite3.connect(path)) as check:
                check.execute("SELECT COUNT(*) FROM htx_user").fetchone()
            # Drop journal/WAL sidecars of the DB being replaced.
            for suffix in ("-wal", "-shm", "-journal"):
                with contextlib.suppress(FileNotFoundError):
                    os.unlink(DB_PATH + suffix)
            os.replace(path, DB_PATH)

        size = os.path.getsize(DB_PATH)
        print(f"[db_sync] Restored DB ({size:,} bytes).", flush=True)
        return True
    except Exception as e:
        # Best effort by design: report and let the caller act on False.
        print(f"[db_sync] Restore failed: {e}", flush=True)
        return False


def backup():
    """Upload a consistent snapshot of DB_PATH to BACKUP_REPO as REMOTE_FILE.

    The snapshot is written to a uniquely named scratch file next to DB_PATH.
    It is checked for readability, uploaded, and the scratch file is then
    removed.

    Returns:
        True after a successful upload; False if DB_PATH does not exist or any
        step failed (the error is printed, not raised).
    """
    if not os.path.exists(DB_PATH):
        print(f"[db_sync] No DB at {DB_PATH} — skipping backup.", flush=True)
        return False
    tmp = None
    try:
        api = _api()
        # Snapshot through SQLite's backup API. A raw copy of the .sqlite3 file
        # would miss committed transactions that have not yet been checkpointed
        # out of the -wal file.
        # The scratch file gets a unique name, so concurrent runs (e.g. `watch`
        # plus a manual `backup`) cannot overwrite or delete each other's copy.
        fd, tmp = tempfile.mkstemp(
            prefix=os.path.basename(DB_PATH) + ".",
            suffix=".upload_tmp",
            dir=os.path.dirname(DB_PATH),
        )
        os.close(fd)  # SQLite reopens the (empty) file by path
        _safe_copy(DB_PATH, tmp)
        size = os.path.getsize(tmp)
        # Verify the copy is readable and count its users for the log/commit message.
        with contextlib.closing(sqlite3.connect(tmp)) as check:
            n = check.execute("SELECT COUNT(*) FROM htx_user").fetchone()[0]
        # NOTE(review): a snapshot with n == 0 is still uploaded and replaces the
        # previous backup; consider refusing that case.
        print(f"[db_sync] Backing up DB ({size:,} bytes, {n} users) → {BACKUP_REPO}...", flush=True)
        api.upload_file(
            path_or_fileobj=tmp,
            path_in_repo=REMOTE_FILE,
            repo_id=BACKUP_REPO,
            repo_type="dataset",
            commit_message=f"Auto-backup from HF Space ({n} users)",
        )
        print(f"[db_sync] Backup complete ({n} users).", flush=True)
        return True
    except Exception as e:
        # Best effort by design (watch() keeps looping); report and return False.
        print(f"[db_sync] Backup failed: {e}", flush=True)
        return False
    finally:
        # Best-effort removal of the scratch snapshot.
        if tmp:
            with contextlib.suppress(OSError):
                os.unlink(tmp)


def watch():
    """Loop forever, calling backup() once per INTERVAL seconds.

    The first backup only happens after a full INTERVAL has elapsed.
    backup() catches and prints its own errors, so a failed attempt does not
    end the loop.
    """
    print(f"[db_sync] Watch mode: backing up every {INTERVAL}s.", flush=True)
    pause = time.sleep
    while 1:
        pause(INTERVAL)
        backup()


if __name__ == "__main__":
    # CLI entry point; with no argument a one-off backup is performed.
    cmd = sys.argv[1] if len(sys.argv) > 1 else "backup"
    # One-shot commands whose boolean result becomes the process exit status.
    one_shot = {"restore": restore, "backup": backup}
    if cmd == "watch":
        watch()  # never returns
    elif cmd in one_shot:
        sys.exit(0 if one_shot[cmd]() else 1)
    else:
        # Usage errors belong on stderr; the message has no placeholders.
        print("Usage: db_sync.py restore|backup|watch", file=sys.stderr, flush=True)
        sys.exit(1)