# Tok-cor / app.py
# Uploaded by Neon-tech; renamed from app.pyc to app.py (commit 2f7c306, verified)
import os
import json
import time
import socket
import threading
import re
import requests
import pyarrow.parquet as pq
import pyarrow as pa
import gc
from pathlib import Path
from huggingface_hub import HfApi
# ── Config ───────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face API token (may be None if unset)
DATASET_REPO = "HuggingFaceFW/fineweb-edu"  # source dataset repository on the Hub
RAW_DIR = "/data/raw"  # where downloaded shards and split chunk files are written
STATE_FILE = "/data/state.json"  # persisted pipeline state: shard statuses + download queue
WORKER_TIMEOUT = 600  # seconds before a "claimed" shard is reclaimed (see reclaim_stale)
MAX_BUFFERED = 9999  # downloader pauses once this many chunks are status "pending"
CC_PREFIX = "data/CC-MAIN-2025-05"  # only shards under this crawl directory are queued
ROWS_PER_CHUNK = 50_000  # target row count per chunk produced by split_parquet
os.makedirs(RAW_DIR, exist_ok=True)
api = HfApi(token=HF_TOKEN)
# ── Keep-alive ────────────────────────────────────────────────────────────────
def serve():
    """Minimal HTTP responder so the platform's port-7860 health check passes.

    Accepts connections forever and answers each one with a fixed
    ``200 OK`` / ``OK`` response, then closes the connection. Runs in a
    daemon thread; never returns.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Allow fast restarts without "address already in use" errors.
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("0.0.0.0", 7860))
    s.listen(5)
    print("βœ“ Listening on port 7860")
    while True:
        conn, _ = s.accept()
        try:
            # sendall: plain send() may transmit only part of the buffer.
            conn.sendall(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
        except OSError:
            # Client disconnected mid-write; don't let it kill the thread.
            pass
        finally:
            conn.close()
# ── Friendly name ─────────────────────────────────────────────────────────────
def friendly_name(hf_path):
    """Map a Hub shard path to a short local file name.

    Paths like ``.../CC-MAIN-2025-05/000_00042.parquet`` become
    ``cc2025-05_000042.parquet``; anything that doesn't match the crawl
    pattern just has its slashes replaced with double underscores.
    """
    match = re.search(r"CC-MAIN-(\d{4}-\d+)/\d+_(\d+)\.parquet", hf_path)
    if match is None:
        return hf_path.replace("/", "__")
    crawl, shard_no = match.group(1), int(match.group(2))
    return f"cc{crawl}_{shard_no:06d}.parquet"
# ── State ─────────────────────────────────────────────────────────────────────
def load_state():
    """Load persisted pipeline state from STATE_FILE, or start fresh.

    Returns a dict with ``"shards"`` (name -> status record) and
    ``"queue"`` (list of Hub paths still to download).
    """
    if not os.path.exists(STATE_FILE):
        print("Starting fresh")
        return {"shards": {}, "queue": []}
    with open(STATE_FILE) as fh:
        state = json.load(fh)
    statuses = [rec["status"] for rec in state["shards"].values()]
    queue = state.get("queue", [])
    done = statuses.count("done")
    claimed = statuses.count("claimed")
    pending = statuses.count("pending")
    print(f"Resuming β€” {done} done / {claimed} claimed / {pending} buffered / {len(queue)} queued")
    return state
def save_state(state):
    """Persist *state* to STATE_FILE atomically (write temp file, then rename)."""
    tmp_path = f"{STATE_FILE}.tmp"
    with open(tmp_path, "w") as fh:
        json.dump(state, fh, indent=2)
    # os.replace is atomic on POSIX, so readers never see a partial file.
    os.replace(tmp_path, STATE_FILE)
# ── Discover ──────────────────────────────────────────────────────────────────
def discover_queue(state):
    """List the dataset repo and append any not-yet-known shards to the queue.

    A shard is "known" if it already appears in the queue or has chunk
    records in ``state["shards"]``. Saves state after updating.
    """
    print("Discovering shards from HF...")
    repo_files = api.list_repo_files(DATASET_REPO, repo_type="dataset")
    known = {rec["hf_path"] for rec in state["shards"].values()}
    known |= set(state.get("queue", []))
    fresh = [
        path
        for path in repo_files
        if path.startswith(CC_PREFIX) and path.endswith(".parquet") and path not in known
    ]
    state["queue"].extend(fresh)
    print(f"βœ“ {len(fresh)} queued | {len(state['queue'])} in queue | {len(state['shards'])} in state")
    save_state(state)
# ── Reclaim timed-out shards ──────────────────────────────────────────────────
def reclaim_stale(state):
    """Return shards whose worker claim exceeded WORKER_TIMEOUT to "pending".

    Saves state only when at least one shard was reclaimed.
    """
    now = time.time()
    stale = [
        (name, rec)
        for name, rec in state["shards"].items()
        if rec["status"] == "claimed"
        and rec["claimed_at"]
        and now - rec["claimed_at"] > WORKER_TIMEOUT
    ]
    for name, rec in stale:
        print(f" ⚠ Reclaiming: {name} (worker: {rec['worker']})")
        rec["status"] = "pending"
        rec["worker"] = None
        rec["claimed_at"] = None
    if stale:
        save_state(state)
# ── Split parquet into chunks ─────────────────────────────────────────────────
def split_parquet(src_path, name):
    """Split a parquet shard into chunk files of ~ROWS_PER_CHUNK rows each.

    Only the ``text`` column is kept. Chunks are written to RAW_DIR as
    ``<name minus .parquet>_chunkNNN.parquet``.

    Args:
        src_path: Path to the downloaded source parquet file.
        name: Friendly shard name the chunk names are derived from.

    Returns:
        List of chunk file names (not full paths) that were written.
    """
    def _flush(batches, idx):
        # Write the accumulated record batches as one chunk file, then free
        # the table promptly β€” these can be hundreds of MB each.
        chunk_name = name.replace(".parquet", f"_chunk{idx:03d}.parquet")
        table = pa.Table.from_batches(batches)
        pq.write_table(table, Path(RAW_DIR) / chunk_name)
        print(f" βœ“ {chunk_name} ({len(table):,} rows)")
        del table
        gc.collect()
        return chunk_name

    pf = pq.ParquetFile(src_path)
    chunk_paths = []
    chunk_idx = 0
    current = []
    buffered_rows = 0  # running count avoids re-summing all batches per iteration
    for batch in pf.iter_batches(batch_size=10_000, columns=["text"]):
        current.append(batch)
        buffered_rows += len(batch)
        if buffered_rows >= ROWS_PER_CHUNK:
            chunk_paths.append(_flush(current, chunk_idx))
            chunk_idx += 1
            current = []
            buffered_rows = 0
    if current:
        # Trailing partial chunk (fewer than ROWS_PER_CHUNK rows).
        chunk_paths.append(_flush(current, chunk_idx))
    return chunk_paths
# ── Download loop ─────────────────────────────────────────────────────────────
def download_loop(state):
    """Producer loop: download queued shards, split them, register the chunks.

    Runs forever (daemon thread) until every shard is done and the queue is
    empty. Each iteration: re-reads the on-disk state (workers may have
    updated shard statuses), reclaims stale claims, throttles when enough
    chunks are already buffered, then downloads and splits the next shard.

    NOTE(review): this loop re-reads STATE_FILE and later calls save_state,
    so a worker write landing between those two points could be overwritten
    β€” presumably acceptable for this pipeline; confirm the worker protocol.
    """
    base_url = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main/"
    while True:
        # Refresh in-memory state from disk so worker-side status changes
        # (claimed/done) are visible here; best-effort, ignore partial reads.
        try:
            with open(STATE_FILE) as f:
                fresh = json.load(f)
            state["shards"] = fresh["shards"]
            state["queue"] = fresh.get("queue", [])
        except Exception:
            pass
        reclaim_stale(state)
        # Throttle: stop producing chunks while enough are waiting unclaimed.
        buffered = sum(1 for v in state["shards"].values() if v["status"] == "pending")
        if buffered >= MAX_BUFFERED:
            time.sleep(30)
            continue
        if not state["queue"]:
            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
            total = len(state["shards"])
            # Only exit once every known chunk is done (and some exist).
            if done == total and total > 0:
                print("βœ“ All shards complete!")
                break
            print(" Queue empty β€” sleeping...")
            time.sleep(60)
            continue
        # Peek (don't pop) the head of the queue β€” it is only removed after
        # the download AND split both succeed, so failures retry the same shard.
        hf_path = state["queue"][0]
        name = friendly_name(hf_path)
        raw_path = Path(RAW_DIR) / name
        tmp_path = Path(RAW_DIR) / f"{name}.tmp"
        url = base_url + hf_path
        print(f" Downloading: {hf_path} β†’ {name}")
        try:
            resp = requests.get(
                url,
                headers={"Authorization": f"Bearer {HF_TOKEN}"},
                timeout=300,
                stream=True,
            )
            resp.raise_for_status()
            # Stream to a .tmp file, then rename: raw_path never exists half-written.
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8 * 1024 * 1024):
                    f.write(chunk)
            tmp_path.rename(raw_path)
        except Exception as e:
            print(f" βœ— Download failed: {e} β€” retrying in 30s")
            tmp_path.unlink(missing_ok=True)
            time.sleep(30)
            continue
        print(f" Splitting: {name}")
        try:
            chunk_names = split_parquet(raw_path, name)
        except Exception as e:
            print(f" βœ— Split failed: {e} β€” retrying in 30s")
            raw_path.unlink(missing_ok=True)
            time.sleep(30)
            continue
        # Source shard no longer needed once its chunks are on disk.
        raw_path.unlink(missing_ok=True)
        state["queue"].pop(0)
        for chunk_name in chunk_names:
            state["shards"][chunk_name] = {
                "status": "pending",
                "hf_path": hf_path,
                "worker": None,
                "claimed_at": None,
                "error": None,
            }
        save_state(state)
        print(f" βœ“ {len(chunk_names)} chunks ready from {name}")
        time.sleep(5)
# ── Monitor ───────────────────────────────────────────────────────────────────
def monitor_loop():
    """Every two minutes, print a progress summary read from STATE_FILE.

    Runs forever in a daemon thread; any read/parse error (e.g. the state
    file not existing yet) silently skips that tick.
    """
    while True:
        time.sleep(120)
        try:
            with open(STATE_FILE) as fh:
                snapshot = json.load(fh)
            shards = snapshot["shards"]
            queue = snapshot.get("queue", [])
            statuses = [rec["status"] for rec in shards.values()]
            done = statuses.count("done")
            claimed = statuses.count("claimed")
            pending = statuses.count("pending")
            total = len(shards) + len(queue)
            pct = done / total * 100 if total else 0
            print(f"[MONITOR] {done}/{total} ({pct:.1f}%) | {claimed} active | {pending} buffered | {len(queue)} queued")
        except Exception:
            pass
# ── Entry point ───────────────────────────────────────────────────────────────
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Keep-alive HTTP server first, so the platform sees port 7860 open.
    threading.Thread(target=serve, daemon=True).start()
    state = load_state()
    discover_queue(state)
    threading.Thread(target=monitor_loop, daemon=True).start()
    threading.Thread(target=download_loop, args=(state,), daemon=True).start()
    # All work happens in daemon threads; keep the main thread alive so the
    # process (and with it, the daemons) doesn't exit.
    while True:
        time.sleep(60)