| import os |
| import json |
| import time |
| import socket |
| import threading |
| import re |
| import requests |
| import pyarrow.parquet as pq |
| import pyarrow as pa |
| import gc |
| from pathlib import Path |
| from huggingface_hub import HfApi |
|
|
| |
# --- Configuration / module-level setup ------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")       # Hub auth token; None if env var unset
DATASET_REPO = "HuggingFaceFW/fineweb-edu"  # source dataset repo on the Hub
RAW_DIR = "/data/raw"                       # downloaded shards + split chunks live here
STATE_FILE = "/data/state.json"             # shared JSON state (shards, queue) on disk
WORKER_TIMEOUT = 600                        # seconds before a "claimed" shard is reclaimed
MAX_BUFFERED = 9999                         # stop downloading while this many chunks are "pending"
CC_PREFIX = "data/CC-MAIN-2025-05"          # only repo files under this crawl prefix are queued
ROWS_PER_CHUNK = 50_000                     # target row count per split chunk


os.makedirs(RAW_DIR, exist_ok=True)  # ensure the chunk directory exists before any download
api = HfApi(token=HF_TOKEN)          # Hub client used for shard discovery
|
|
| |
def serve():
    """Serve a minimal HTTP 200 health-check endpoint on port 7860.

    Runs forever (intended for a daemon thread) and answers every TCP
    connection with a fixed "OK" response so the hosting platform's
    liveness probe stays green.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("0.0.0.0", 7860))
    s.listen(5)
    print("β Listening on port 7860")
    while True:
        conn, _ = s.accept()
        try:
            # sendall: plain send() may write only part of the buffer.
            conn.sendall(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
        finally:
            # Always close the client socket, even if sendall raises,
            # so a misbehaving client can't leak fds or kill the loop's
            # cleanup path.
            conn.close()
|
|
| |
def friendly_name(hf_path):
    """Map a Hub shard path to a short, sortable local filename.

    Paths matching the CC-MAIN layout become ``cc<crawl>_<index>.parquet``
    with a zero-padded six-digit index; anything else is flattened by
    replacing path separators.
    """
    match = re.search(r"CC-MAIN-(\d{4}-\d+)/\d+_(\d+)\.parquet", hf_path)
    if match is None:
        return hf_path.replace("/", "__")
    crawl, index = match.groups()
    return f"cc{crawl}_{int(index):06d}.parquet"
|
|
| |
def load_state():
    """Load orchestrator state from STATE_FILE, or start a fresh one.

    Returns a dict with a "shards" mapping and a "queue" list; prints a
    short resume/fresh summary either way.
    """
    if not os.path.exists(STATE_FILE):
        print("Starting fresh")
        return {"shards": {}, "queue": []}

    with open(STATE_FILE) as fh:
        state = json.load(fh)
    queue = state.get("queue", [])
    statuses = [info["status"] for info in state["shards"].values()]
    done = statuses.count("done")
    claimed = statuses.count("claimed")
    pending = statuses.count("pending")
    print(f"Resuming β {done} done / {claimed} claimed / {pending} buffered / {len(queue)} queued")
    return state
|
|
def save_state(state):
    """Persist *state* atomically: write a temp file, then rename over.

    The rename via os.replace is atomic on POSIX, so concurrent readers
    never observe a half-written JSON file.
    """
    scratch = Path(STATE_FILE + ".tmp")
    scratch.write_text(json.dumps(state, indent=2))
    os.replace(scratch, STATE_FILE)
|
|
| |
def discover_queue(state):
    """List the dataset repo and enqueue parquet shards not yet tracked.

    A shard counts as known if it already appears in the queue or as the
    hf_path of any tracked chunk. State is saved after discovery.
    """
    print("Discovering shards from HF...")
    repo_files = api.list_repo_files(DATASET_REPO, repo_type="dataset")
    known = set(state.get("queue", []))
    known.update(info["hf_path"] for info in state["shards"].values())
    fresh = [
        path
        for path in repo_files
        if path.startswith(CC_PREFIX) and path.endswith(".parquet") and path not in known
    ]
    state["queue"].extend(fresh)
    print(f"β {len(fresh)} queued | {len(state['queue'])} in queue | {len(state['shards'])} in state")
    save_state(state)
|
|
| |
def reclaim_stale(state):
    """Return long-claimed shards to the pending pool.

    Any shard claimed more than WORKER_TIMEOUT seconds ago is assumed
    abandoned (dead worker) and reset to "pending". State is saved only
    if something changed.
    """
    cutoff = time.time() - WORKER_TIMEOUT
    changed = False
    for name, info in state["shards"].items():
        if info["status"] != "claimed" or not info["claimed_at"]:
            continue
        if info["claimed_at"] >= cutoff:
            continue
        print(f" β Reclaiming: {name} (worker: {info['worker']})")
        info.update(status="pending", worker=None, claimed_at=None)
        changed = True
    if changed:
        save_state(state)
|
|
| |
def split_parquet(src_path, name):
    """Split a downloaded parquet shard into ~ROWS_PER_CHUNK-row chunk files.

    Only the "text" column is retained. Chunks are written to RAW_DIR as
    ``<name>_chunkNNN.parquet``.

    Args:
        src_path: path of the source parquet file.
        name: friendly shard name; its ".parquet" suffix is rewritten per chunk.

    Returns:
        List of chunk file names (relative to RAW_DIR), in write order.
    """
    pf = pq.ParquetFile(src_path)
    chunk_paths = []
    buffered = []       # record batches accumulated for the current chunk
    buffered_rows = 0   # running count — avoids re-summing all batches per iteration

    def _flush():
        # Write the buffered batches as the next chunk file and record its name.
        chunk_name = name.replace(".parquet", f"_chunk{len(chunk_paths):03d}.parquet")
        table = pa.Table.from_batches(buffered)
        pq.write_table(table, Path(RAW_DIR) / chunk_name)
        print(f" β {chunk_name} ({len(table):,} rows)")
        chunk_paths.append(chunk_name)
        del table
        gc.collect()  # chunk tables can be large; release memory promptly

    for batch in pf.iter_batches(batch_size=10_000, columns=["text"]):
        buffered.append(batch)
        buffered_rows += len(batch)
        if buffered_rows >= ROWS_PER_CHUNK:
            _flush()
            buffered = []
            buffered_rows = 0

    if buffered:  # trailing partial chunk
        _flush()

    return chunk_paths
|
|
| |
def download_loop(state):
    """Producer loop: download queued shards, split them, register chunks.

    Each pass re-reads STATE_FILE (so claims made by external workers are
    seen), reclaims stale claims, throttles when enough chunks are already
    buffered, then downloads and splits the head of the queue. Runs until
    the queue is empty and every tracked shard is "done". Download and
    split failures are retried after a 30s backoff; the shard stays at the
    head of the queue.
    """
    base_url = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main/"

    while True:
        # Best-effort refresh from disk; on any read error keep the
        # in-memory copy rather than crashing the loop.
        try:
            with open(STATE_FILE) as f:
                fresh = json.load(f)
            state["shards"] = fresh["shards"]
            state["queue"] = fresh.get("queue", [])
        except Exception:
            pass

        reclaim_stale(state)

        # Throttle: don't pile up more unclaimed chunks than MAX_BUFFERED.
        buffered = sum(1 for v in state["shards"].values() if v["status"] == "pending")
        if buffered >= MAX_BUFFERED:
            time.sleep(30)
            continue

        if not state["queue"]:
            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
            total = len(state["shards"])
            if done == total and total > 0:
                print("β All shards complete!")
                break
            print(" Queue empty β sleeping...")
            time.sleep(60)
            continue

        hf_path = state["queue"][0]
        name = friendly_name(hf_path)
        raw_path = Path(RAW_DIR) / name
        tmp_path = Path(RAW_DIR) / f"{name}.tmp"
        url = base_url + hf_path

        print(f" Downloading: {hf_path} β {name}")
        try:
            # Stream into a .tmp file, then rename, so a crash mid-download
            # never leaves a partial file that looks complete. Using the
            # response as a context manager releases the pooled connection
            # (the previous version leaked one streamed response per shard).
            with requests.get(
                url,
                headers={"Authorization": f"Bearer {HF_TOKEN}"},
                timeout=300,
                stream=True,
            ) as resp:
                resp.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=8 * 1024 * 1024):
                        if chunk:  # filter keep-alive empty chunks
                            f.write(chunk)
            tmp_path.rename(raw_path)
        except Exception as e:
            print(f" β Download failed: {e} β retrying in 30s")
            tmp_path.unlink(missing_ok=True)
            time.sleep(30)
            continue

        print(f" Splitting: {name}")
        try:
            chunk_names = split_parquet(raw_path, name)
        except Exception as e:
            print(f" β Split failed: {e} β retrying in 30s")
            raw_path.unlink(missing_ok=True)
            time.sleep(30)
            continue

        # The raw shard is redundant once split into chunks.
        raw_path.unlink(missing_ok=True)

        # Dequeue and register chunks only after a successful split, so a
        # crash before this point simply re-processes the same shard.
        state["queue"].pop(0)
        for chunk_name in chunk_names:
            state["shards"][chunk_name] = {
                "status": "pending",
                "hf_path": hf_path,
                "worker": None,
                "claimed_at": None,
                "error": None,
            }
        save_state(state)
        print(f" β {len(chunk_names)} chunks ready from {name}")
        time.sleep(5)
|
|
| |
def monitor_loop():
    """Print a progress summary every two minutes.

    Reads STATE_FILE from disk each cycle; any read/parse error is
    silently skipped so the monitor never dies on a partial state.
    """
    while True:
        time.sleep(120)
        try:
            snapshot = json.loads(Path(STATE_FILE).read_text())
            shards = snapshot["shards"]
            queue = snapshot.get("queue", [])
            statuses = [info["status"] for info in shards.values()]
            done = statuses.count("done")
            claimed = statuses.count("claimed")
            pending = statuses.count("pending")
            total = len(shards) + len(queue)
            pct = (done / total * 100) if total else 0
            print(f"[MONITOR] {done}/{total} ({pct:.1f}%) | {claimed} active | {pending} buffered | {len(queue)} queued")
        except Exception:
            pass
|
|
| |
if __name__ == "__main__":
    # Health-check server first, so the platform sees the port open early.
    threading.Thread(target=serve, daemon=True).start()
    state = load_state()
    discover_queue(state)
    # Background workers: progress monitor and the download/split producer.
    for worker, worker_args in ((monitor_loop, ()), (download_loop, (state,))):
        threading.Thread(target=worker, args=worker_args, daemon=True).start()
    # All threads are daemons; keep the main thread alive indefinitely.
    while True:
        time.sleep(60)