Neon-tech commited on
Commit
a29dc9c
Β·
verified Β·
1 Parent(s): 8d0f444

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -0
app.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import socket
5
+ import threading
6
+ import io
7
+ import requests
8
+ from pathlib import Path
9
+ from huggingface_hub import HfApi, list_repo_tree
10
+
11
+ # ── Config ───────────────────────────────────────────────────────────────────
12
+ HF_TOKEN = os.environ.get("HF_TOKEN")
13
+ DATASET_REPO = "HuggingFaceFW/fineweb-edu"
14
+ RAW_DIR = "/data/raw"
15
+ STATE_FILE = "/data/state.json"
16
+ WORKER_TIMEOUT = 600 # 10 min β€” reclaim stale claimed shards
17
+
18
+ # CC-MAIN-2025 prefix filter
19
+ CC_PREFIX = "data/CC-MAIN-2025"
20
+
21
+ os.makedirs(RAW_DIR, exist_ok=True)
22
+
23
+ api = HfApi(token=HF_TOKEN)
24
+
25
+ # ── Keep-alive ────────────────────────────────────────────────────────────────
26
+ def serve():
27
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
28
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
29
+ s.bind(("0.0.0.0", 7860))
30
+ s.listen(5)
31
+ print("βœ“ Listening on port 7860")
32
+ while True:
33
+ conn, _ = s.accept()
34
+ conn.send(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
35
+ conn.close()
36
+
37
+ # ── State ─────────────────────────────────────────────────────────────────────
38
+ def load_state():
39
+ if os.path.exists(STATE_FILE):
40
+ with open(STATE_FILE) as f:
41
+ state = json.load(f)
42
+ total = len(state["shards"])
43
+ done = sum(1 for s in state["shards"].values() if s["status"] == "done")
44
+ claimed = sum(1 for s in state["shards"].values() if s["status"] == "claimed")
45
+ pending = sum(1 for s in state["shards"].values() if s["status"] == "pending")
46
+ print(f"Resuming β€” {done} done / {claimed} claimed / {pending} pending / {total} total")
47
+ else:
48
+ state = {"shards": {}}
49
+ print("Starting fresh")
50
+ return state
51
+
52
+ def save_state(state):
53
+ tmp = STATE_FILE + ".tmp"
54
+ with open(tmp, "w") as f:
55
+ json.dump(state, f, indent=2)
56
+ os.replace(tmp, STATE_FILE)
57
+
58
+ # ── Discover all CC-MAIN-2025 parquet files ───────────────────────────────────
59
+ def discover_shards(state):
60
+ print("Discovering shards from HF...")
61
+ files = api.list_repo_files(DATASET_REPO, repo_type="dataset")
62
+ new_count = 0
63
+ for f in files:
64
+ if f.startswith(CC_PREFIX) and f.endswith(".parquet"):
65
+ if f not in state["shards"]:
66
+ state["shards"][f] = {
67
+ "status": "pending",
68
+ "worker": None,
69
+ "claimed_at": None,
70
+ }
71
+ new_count += 1
72
+ print(f"βœ“ {new_count} new shards discovered | {len(state['shards'])} total")
73
+ save_state(state)
74
+
75
+ # ── Reclaim timed-out shards ──────────────────────────────────────────────────
76
+ def reclaim_stale(state):
77
+ now = time.time()
78
+ reclaimed = 0
79
+ for shard, info in state["shards"].items():
80
+ if info["status"] == "claimed" and info["claimed_at"]:
81
+ if now - info["claimed_at"] > WORKER_TIMEOUT:
82
+ print(f" ⚠ Reclaiming stale shard: {shard} (worker: {info['worker']})")
83
+ info["status"] = "pending"
84
+ info["worker"] = None
85
+ info["claimed_at"] = None
86
+ reclaimed += 1
87
+ if reclaimed:
88
+ save_state(state)
89
+ return reclaimed
90
+
91
+ # ── Download pending shards to /data/raw ─────────────────────────────────────
92
+ def download_loop(state):
93
+ base_url = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main/"
94
+
95
+ while True:
96
+ # Reclaim stale first
97
+ reclaim_stale(state)
98
+
99
+ # Reload state to pick up worker updates
100
+ if os.path.exists(STATE_FILE):
101
+ with open(STATE_FILE) as f:
102
+ state["shards"] = json.load(f)["shards"]
103
+
104
+ # Count how many raw files already sitting in /data/raw (not yet claimed)
105
+ raw_files = list(Path(RAW_DIR).glob("*.parquet"))
106
+ pending_raw = len(raw_files)
107
+
108
+ # Keep at most 4 shards pre-downloaded to avoid filling disk
109
+ if pending_raw >= 4:
110
+ print(f" Buffer full ({pending_raw} shards waiting) β€” sleeping...")
111
+ time.sleep(60)
112
+ continue
113
+
114
+ # Find next pending shard to download
115
+ to_download = None
116
+ for shard, info in state["shards"].items():
117
+ if info["status"] == "pending":
118
+ raw_name = shard.replace("/", "__") + ".parquet"
119
+ raw_path = Path(RAW_DIR) / raw_name
120
+ if not raw_path.exists():
121
+ to_download = shard
122
+ break
123
+
124
+ if not to_download:
125
+ done = sum(1 for s in state["shards"].values() if s["status"] == "done")
126
+ total = len(state["shards"])
127
+ if done == total:
128
+ print("βœ“ All shards complete!")
129
+ break
130
+ print(" Nothing to download right now β€” sleeping...")
131
+ time.sleep(60)
132
+ continue
133
+
134
+ # Download it
135
+ url = base_url + to_download
136
+ raw_name = to_download.replace("/", "__") + ".parquet"
137
+ raw_path = Path(RAW_DIR) / raw_name
138
+
139
+ print(f" Downloading: {to_download}")
140
+ try:
141
+ resp = requests.get(
142
+ url,
143
+ headers={"Authorization": f"Bearer {HF_TOKEN}"},
144
+ timeout=300,
145
+ stream=True,
146
+ )
147
+ resp.raise_for_status()
148
+ with open(raw_path, "wb") as f:
149
+ for chunk in resp.iter_content(chunk_size=8 * 1024 * 1024):
150
+ f.write(chunk)
151
+ print(f" βœ“ Downloaded: {raw_name}")
152
+ except Exception as e:
153
+ print(f" βœ— Download failed: {e}")
154
+ time.sleep(30)
155
+ continue
156
+
157
+ time.sleep(5)
158
+
159
+ # ── Monitor loop β€” prints progress ───────────────────────────────────────────
160
+ def monitor_loop(state):
161
+ while True:
162
+ time.sleep(120)
163
+ if os.path.exists(STATE_FILE):
164
+ with open(STATE_FILE) as f:
165
+ s = json.load(f)["shards"]
166
+ done = sum(1 for v in s.values() if v["status"] == "done")
167
+ claimed = sum(1 for v in s.values() if v["status"] == "claimed")
168
+ pending = sum(1 for v in s.values() if v["status"] == "pending")
169
+ total = len(s)
170
+ pct = (done / total * 100) if total else 0
171
+ print(f"[MONITOR] {done}/{total} done ({pct:.1f}%) | {claimed} active | {pending} pending")
172
+
173
+ # ── Entry point ───────────────────────────────────────────────────────────────
174
+ if __name__ == "__main__":
175
+ threading.Thread(target=serve, daemon=True).start()
176
+
177
+ state = load_state()
178
+ discover_shards(state)
179
+
180
+ threading.Thread(target=monitor_loop, args=(state,), daemon=True).start()
181
+ threading.Thread(target=download_loop, args=(state,), daemon=True).start()
182
+
183
+ while True:
184
+ time.sleep(60)