import os
import json
import time
import socket
import threading
import re
import gc
from pathlib import Path

import requests
import pyarrow.parquet as pq
import pyarrow as pa
from huggingface_hub import HfApi

# ── Config ───────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")           # HF access token (read scope)
DATASET_REPO = "HuggingFaceFW/fineweb-edu"      # source dataset repository
RAW_DIR = "/data/raw"                           # where downloaded/split parquet files land
STATE_FILE = "/data/state.json"                 # persisted shard/queue state
WORKER_TIMEOUT = 600                            # seconds before a claimed shard is reclaimed
MAX_BUFFERED = 9999                             # max "pending" chunks buffered on disk
CC_PREFIX = "data/CC-MAIN-2025-05"              # only shards under this crawl are queued
ROWS_PER_CHUNK = 50_000                         # target rows per split chunk

os.makedirs(RAW_DIR, exist_ok=True)
api = HfApi(token=HF_TOKEN)


# ── Keep-alive ────────────────────────────────────────────────────────────────
def serve():
    """Answer every TCP connection on :7860 with a tiny HTTP 200.

    Exists only so the hosting platform's health check sees the port as live.
    Runs forever; intended to be started on a daemon thread.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("0.0.0.0", 7860))
    s.listen(5)
    print("✓ Listening on port 7860")
    while True:
        conn, _ = s.accept()
        try:
            # sendall (not send): send() may write only part of the response,
            # and a peer that resets mid-write must not kill this thread.
            conn.sendall(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
        except OSError:
            pass  # client went away — ignore and keep serving
        finally:
            conn.close()


# ── Friendly name ─────────────────────────────────────────────────────────────
def friendly_name(hf_path):
    """Map a repo path like 'data/CC-MAIN-2025-05/000_00001.parquet'
    to a flat local name like 'cc2025-05_000001.parquet'.

    Falls back to the path with '/' replaced by '__' when the pattern
    does not match, so the result is always a valid single filename.
    """
    m = re.search(r"CC-MAIN-(\d{4}-\d+)/\d+_(\d+)\.parquet", hf_path)
    if m:
        return f"cc{m.group(1)}_{int(m.group(2)):06d}.parquet"
    return hf_path.replace("/", "__")


# ── State ─────────────────────────────────────────────────────────────────────
def load_state():
    """Load persisted state from STATE_FILE, or return a fresh empty state.

    State shape: {"shards": {chunk_name: {...}}, "queue": [hf_path, ...]}.
    Prints a resume/fresh summary as a side effect.
    """
    if os.path.exists(STATE_FILE):
        with open(STATE_FILE) as f:
            state = json.load(f)
        shards = state["shards"]
        queue = state.get("queue", [])
        done = sum(1 for v in shards.values() if v["status"] == "done")
        claimed = sum(1 for v in shards.values() if v["status"] == "claimed")
        pending = sum(1 for v in shards.values() if v["status"] == "pending")
        print(f"Resuming — {done} done / {claimed} claimed / {pending} buffered / {len(queue)} queued")
    else:
        state = {"shards": {}, "queue": []}
        print("Starting fresh")
    return state


def save_state(state):
    """Atomically persist *state* to STATE_FILE (write tmp, then rename)."""
    tmp = STATE_FILE + ".tmp"
    with open(tmp, "w") as f:
        json.dump(state, f, indent=2)
    # os.replace is atomic on POSIX — readers never see a half-written file.
    os.replace(tmp, STATE_FILE)
# ── Discover ──────────────────────────────────────────────────────────────────
def discover_queue(state):
    """List the dataset repo and append any not-yet-known parquet shards
    under CC_PREFIX to state["queue"]. Persists state afterwards.
    """
    print("Discovering shards from HF...")
    files = api.list_repo_files(DATASET_REPO, repo_type="dataset")
    # Known = every hf_path already split into chunks, plus everything queued.
    known = {v["hf_path"] for v in state["shards"].values()} | set(state.get("queue", []))
    new_count = 0
    for f in files:
        if f.startswith(CC_PREFIX) and f.endswith(".parquet") and f not in known:
            state["queue"].append(f)
            new_count += 1
    print(f"✓ {new_count} queued | {len(state['queue'])} in queue | {len(state['shards'])} in state")
    save_state(state)


# ── Reclaim timed-out shards ──────────────────────────────────────────────────
def reclaim_stale(state):
    """Return chunks claimed longer than WORKER_TIMEOUT seconds ago to
    "pending" so another worker can pick them up. Saves state only when
    something was actually reclaimed.
    """
    now = time.time()
    reclaimed = 0
    for name, info in state["shards"].items():
        if info["status"] == "claimed" and info["claimed_at"]:
            if now - info["claimed_at"] > WORKER_TIMEOUT:
                print(f" ⚠ Reclaiming: {name} (worker: {info['worker']})")
                info["status"] = "pending"
                info["worker"] = None
                info["claimed_at"] = None
                reclaimed += 1
    if reclaimed:
        save_state(state)


# ── Split parquet into chunks ─────────────────────────────────────────────────
def split_parquet(src_path, name):
    """Stream *src_path* and rewrite its "text" column as ~ROWS_PER_CHUNK-row
    chunk files in RAW_DIR.

    Chunk files are named  <name minus .parquet>_chunkNNN.parquet.
    Returns the list of chunk file names (names only, not full paths).
    """

    def _flush(batches, idx):
        # Write the accumulated record batches as one chunk file, log it,
        # and aggressively free the table (these can be hundreds of MB).
        chunk_name = name.replace(".parquet", f"_chunk{idx:03d}.parquet")
        chunk_path = Path(RAW_DIR) / chunk_name
        table = pa.Table.from_batches(batches)
        pq.write_table(table, chunk_path)
        print(f" ✓ {chunk_name} ({len(table):,} rows)")
        del table
        gc.collect()
        return chunk_name

    pf = pq.ParquetFile(src_path)
    chunk_paths = []
    chunk_idx = 0
    current = []
    buffered_rows = 0  # running count — avoids re-summing every batch each iteration
    for batch in pf.iter_batches(batch_size=10_000, columns=["text"]):
        current.append(batch)
        buffered_rows += len(batch)
        if buffered_rows >= ROWS_PER_CHUNK:
            chunk_paths.append(_flush(current, chunk_idx))
            chunk_idx += 1
            current = []
            buffered_rows = 0
    if current:  # trailing partial chunk
        chunk_paths.append(_flush(current, chunk_idx))
    return chunk_paths
# ── Download loop ─────────────────────────────────────────────────────────────
def download_loop(state):
    """Producer loop: pull shards off state["queue"], download each from the
    HF CDN, split it into chunks, and register the chunks as "pending".

    Runs until the queue is empty AND every registered chunk is "done".
    All failures are retried after a 30s back-off; nothing is ever dropped.
    """
    base_url = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main/"
    while True:
        # Refresh from disk so claims made by external workers are visible.
        # Best-effort: on any read/parse problem keep the in-memory copy.
        try:
            with open(STATE_FILE) as f:
                fresh = json.load(f)
            state["shards"] = fresh["shards"]
            state["queue"] = fresh.get("queue", [])
        except (OSError, json.JSONDecodeError, KeyError):
            pass
        reclaim_stale(state)

        # Throttle: don't split more shards while too many chunks sit unclaimed.
        buffered = sum(1 for v in state["shards"].values() if v["status"] == "pending")
        if buffered >= MAX_BUFFERED:
            time.sleep(30)
            continue

        if not state["queue"]:
            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
            total = len(state["shards"])
            if done == total and total > 0:
                print("✓ All shards complete!")
                break
            print(" Queue empty — sleeping...")
            time.sleep(60)
            continue

        hf_path = state["queue"][0]
        name = friendly_name(hf_path)
        raw_path = Path(RAW_DIR) / name
        tmp_path = Path(RAW_DIR) / f"{name}.tmp"  # rename-on-success: no partial files
        url = base_url + hf_path

        print(f" Downloading: {hf_path} → {name}")
        try:
            # `with` guarantees the streamed connection is closed even when
            # raise_for_status() or the body iteration raises (resource leak fix).
            with requests.get(
                url,
                headers={"Authorization": f"Bearer {HF_TOKEN}"},
                timeout=300,
                stream=True,
            ) as resp:
                resp.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=8 * 1024 * 1024):
                        f.write(chunk)
            tmp_path.rename(raw_path)
        except Exception as e:
            print(f" ✗ Download failed: {e} — retrying in 30s")
            tmp_path.unlink(missing_ok=True)
            time.sleep(30)
            continue

        print(f" Splitting: {name}")
        try:
            chunk_names = split_parquet(raw_path, name)
        except Exception as e:
            print(f" ✗ Split failed: {e} — retrying in 30s")
            raw_path.unlink(missing_ok=True)
            time.sleep(30)
            continue
        raw_path.unlink(missing_ok=True)  # chunks exist now; raw copy is redundant

        # Only dequeue after a fully successful download+split, so a crash
        # anywhere above simply retries the same shard on restart.
        state["queue"].pop(0)
        for chunk_name in chunk_names:
            state["shards"][chunk_name] = {
                "status": "pending",
                "hf_path": hf_path,
                "worker": None,
                "claimed_at": None,
                "error": None,
            }
        save_state(state)
        print(f" ✓ {len(chunk_names)} chunks ready from {name}")
        time.sleep(5)


# ── Monitor ───────────────────────────────────────────────────────────────────
def monitor_loop():
    """Print a progress line every two minutes, reading state from disk.

    Read-only and best-effort: a missing/half-written state file is skipped.
    """
    while True:
        time.sleep(120)
        try:
            with open(STATE_FILE) as f:
                s = json.load(f)
            shards = s["shards"]
            queue = s.get("queue", [])
            done = sum(1 for v in shards.values() if v["status"] == "done")
            claimed = sum(1 for v in shards.values() if v["status"] == "claimed")
            pending = sum(1 for v in shards.values() if v["status"] == "pending")
            total = len(shards) + len(queue)
            pct = (done / total * 100) if total else 0
            print(f"[MONITOR] {done}/{total} ({pct:.1f}%) | {claimed} active | {pending} buffered | {len(queue)} queued")
        except (OSError, json.JSONDecodeError, KeyError):
            pass  # narrow: only expected read/parse races, not real bugs


# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    threading.Thread(target=serve, daemon=True).start()
    state = load_state()
    discover_queue(state)
    threading.Thread(target=monitor_loop, daemon=True).start()
    threading.Thread(target=download_loop, args=(state,), daemon=True).start()
    # Main thread idles forever: keeps the daemon threads (and the keep-alive
    # HTTP port) running even after the download loop finishes.
    while True:
        time.sleep(60)