File size: 5,162 Bytes
c001f24
 
 
cdd3652
e5657d7
 
d30571e
e5657d7
7a709ef
c001f24
 
 
 
e5657d7
 
b49d3e6
e5657d7
 
b49d3e6
 
 
 
e5657d7
b49d3e6
e5657d7
b49d3e6
 
e5657d7
b49d3e6
 
 
 
d30571e
 
 
 
 
b49d3e6
d30571e
 
e5657d7
d30571e
 
e5657d7
 
 
b49d3e6
 
d30571e
e5657d7
 
d30571e
cdd3652
c001f24
d30571e
e5657d7
d30571e
 
7a709ef
e5657d7
b49d3e6
 
 
 
d30571e
b49d3e6
 
 
 
a0c684f
 
b49d3e6
 
a0c684f
 
b49d3e6
d30571e
b49d3e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d30571e
b49d3e6
 
 
 
 
 
 
 
d30571e
b49d3e6
e5657d7
d30571e
c001f24
b49d3e6
 
 
 
 
 
 
 
c001f24
d30571e
c001f24
 
e5657d7
 
 
 
d30571e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import sys
import shutil
import sqlite3
import json
import time
import hashlib
from datetime import datetime
from huggingface_hub import snapshot_download, HfApi

# Configuration
# Repo coordinates and credentials come from the environment; if either is
# missing, upload() is a no-op (see the guard at the top of upload()).
REPO_ID = os.environ.get("DATASET_REPO_ID")  # target dataset repo id on the Hub
HF_TOKEN = os.environ.get("HF_TOKEN")  # token used for authenticated API calls
DATA_DIR = "data_repo"  # local working directory mirrored to the dataset repo
DB_FILE = os.path.join(DATA_DIR, "database.db")  # SQLite DB synced as one file
STATE_FILE = os.path.join(DATA_DIR, "sync_state.json")  # records what was already uploaded
LOCK_FILE = "/tmp/hf_sync.lock"  # prevents overlapping upload() runs

# Shared Hub API client; HF_TOKEN may be None here — upload() checks it before use.
api = HfApi(token=HF_TOKEN)

def get_state():
    """Load the persisted sync state from STATE_FILE.

    Returns:
        dict: the parsed state, or a fresh default state
        (``{"uploaded_files": {}, "last_db_hash": None, "version": 0}``)
        when the file is missing, unreadable, or contains invalid JSON.
    """
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, 'r') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # A corrupt or unreadable state file should not abort the sync;
            # log it and fall through to a clean default. (The original bare
            # `except: pass` also swallowed SystemExit/KeyboardInterrupt.)
            print(f"Could not read state file, starting fresh: {e}")
    return {"uploaded_files": {}, "last_db_hash": None, "version": 0}

def save_state(state):
    """Persist *state* to STATE_FILE as pretty-printed JSON.

    Mutates *state* in place by stamping it with an ISO-8601
    ``last_update`` timestamp before serializing.
    """
    state["last_update"] = datetime.now().isoformat()
    serialized = json.dumps(state, indent=2)
    with open(STATE_FILE, 'w') as out:
        out.write(serialized)

def get_file_hash(path):
    """Return the MD5 hex digest of the file at *path*, or None if it
    does not exist.

    Reads in 4 KiB chunks so arbitrarily large files never need to fit
    in memory. MD5 is used purely for change detection, not security.
    """
    if not os.path.exists(path):
        return None
    digest = hashlib.md5()
    with open(path, 'rb') as fh:
        chunk = fh.read(4096)
        while chunk:
            digest.update(chunk)
            chunk = fh.read(4096)
    return digest.hexdigest()

def safe_db_backup():
    """Create a consistent point-in-time copy of the SQLite database.

    Uses the sqlite3 online backup API, which is safe to run while the
    database is being written to by other connections.

    Returns:
        str | None: path of the ``.bak`` copy, or None when DB_FILE does
        not exist or the backup fails.
    """
    if not os.path.exists(DB_FILE):
        return None
    backup_db = DB_FILE + ".bak"
    source_conn = None
    dest_conn = None
    try:
        source_conn = sqlite3.connect(DB_FILE)
        dest_conn = sqlite3.connect(backup_db)
        # `with dest_conn:` opens a transaction and commits it on success.
        with dest_conn:
            source_conn.backup(dest_conn)
        return backup_db
    except Exception as e:
        # Best-effort by design: a failed backup just skips this DB sync.
        print(f"Database backup failed: {e}")
        return None
    finally:
        # Always release both connections; the original closed them only on
        # the success path, leaking handles whenever connect()/backup() raised.
        if source_conn is not None:
            source_conn.close()
        if dest_conn is not None:
            dest_conn.close()

def upload():
    """Incrementally sync the database, data files, and sync state to the Hub.

    Guards against concurrent runs with a lock file; a lock older than 10
    minutes is treated as stale (crashed run) and taken over. Only files
    whose path/size/mtime signature changed since the last recorded sync
    are uploaded. No-op when REPO_ID or HF_TOKEN is not configured.
    """
    if not REPO_ID or not HF_TOKEN:
        return
    if os.path.exists(LOCK_FILE):
        if time.time() - os.path.getmtime(LOCK_FILE) < 600:
            return  # another sync appears to be in progress
        # Stale lock from a crashed run: remove it so we can re-acquire below.
        try:
            os.remove(LOCK_FILE)
        except OSError:
            pass

    # O_CREAT|O_EXCL makes acquisition atomic — two processes can no longer
    # both pass the exists() check above and proceed (the original open()
    # silently overwrote a lock created between the check and the write).
    try:
        fd = os.open(LOCK_FILE, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        return  # lost the race to another process
    with os.fdopen(fd, 'w') as f:
        f.write(str(os.getpid()))

    try:
        state = get_state()
        # Older state files on the Hub may predate the "uploaded_files" key;
        # the original indexed it directly and would KeyError.
        uploaded = state.setdefault("uploaded_files", {})
        changes_made = False

        # 1. Sync Database (as one consistent snapshot)
        backup_path = safe_db_backup()
        if backup_path:
            db_hash = get_file_hash(backup_path)
            if db_hash != state.get("last_db_hash"):
                print("Syncing Database...")
                # Upload the backup file directly without replacing the active database
                api.upload_file(path_or_fileobj=backup_path, path_in_repo="database.db", repo_id=REPO_ID, repo_type="dataset")
                state["last_db_hash"] = db_hash
                changes_made = True
            # Clean up the backup file regardless of whether it was uploaded
            if os.path.exists(backup_path):
                os.remove(backup_path)

        # 2. Sync Files Iteratively (Immune to folder timeouts)
        for sub_dir in ['uploads', 'processed', 'output']:
            dir_path = os.path.join(DATA_DIR, sub_dir)
            if not os.path.exists(dir_path):
                continue

            for root, _, files in os.walk(dir_path):
                for file in files:
                    full_path = os.path.join(root, file)
                    rel_path = os.path.relpath(full_path, DATA_DIR)

                    # Cheap change signature (size + mtime) instead of hashing
                    # potentially thousands of large files.
                    mtime = os.path.getmtime(full_path)
                    size = os.path.getsize(full_path)
                    file_id = f"{rel_path}_{size}_{mtime}"

                    if uploaded.get(rel_path) != file_id:
                        print(f"Syncing new file: {rel_path}")
                        try:
                            api.upload_file(
                                path_or_fileobj=full_path,
                                path_in_repo=rel_path,
                                repo_id=REPO_ID,
                                repo_type="dataset"
                            )
                            uploaded[rel_path] = file_id
                            changes_made = True
                        except Exception as e:
                            # Best-effort: skip this file, keep syncing the rest.
                            print(f"Failed to upload {rel_path}: {e}")

        if changes_made:
            state["version"] = state.get("version", 0) + 1
            save_state(state)
            # Sync state file too, so the next run knows what was uploaded.
            api.upload_file(path_or_fileobj=STATE_FILE, path_in_repo="sync_state.json", repo_id=REPO_ID, repo_type="dataset")
            print(f"Sync complete. Version {state['version']} saved.")
        else:
            print("Everything up to date.")

    except Exception as e:
        # Top-level boundary for a background job: report, never crash the host.
        print(f"Upload process failed: {e}")
    finally:
        if os.path.exists(LOCK_FILE):
            os.remove(LOCK_FILE)

def download():
    """Pull the full dataset snapshot from the Hub into DATA_DIR.

    No-op when REPO_ID is not configured; failures are reported but not
    raised (best-effort, same as upload()).
    """
    if not REPO_ID:
        return
    print(f"Downloading data from {REPO_ID}...")
    try:
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="dataset",
            local_dir=DATA_DIR,
            token=HF_TOKEN,
            max_workers=8,
        )
        print("Download successful.")
    except Exception as e:
        print(f"Download failed: {e}")

def init_local():
    """Ensure the local working sub-directories exist under DATA_DIR."""
    for name in ('output', 'processed', 'uploads'):
        os.makedirs(os.path.join(DATA_DIR, name), exist_ok=True)

if __name__ == "__main__":
    # CLI entry point: dispatch on the first argument, defaulting to usage help.
    action = sys.argv[1] if len(sys.argv) > 1 else "help"
    handlers = {"download": download, "upload": upload, "init": init_local}
    handler = handlers.get(action)
    if handler is not None:
        handler()
    else:
        print("Usage: python hf_sync.py [download|upload|init]")