Cor / app.py
Neon-tech's picture
Update app.py
528f3a5 verified
Raw
History Blame Contribute Delete
11.7 kB
import gc
import json
import os
import socket
import threading
import time
from collections import Counter
from pathlib import Path
from urllib.parse import quote

import pyarrow.parquet as pq
import requests
from huggingface_hub import HfApi
# ── Config ───────────────────────────────────────────────────────────────────
# Hugging Face access token; None when the HF_TOKEN environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Directory that receives the downloaded / converted shard files.
RAW_DIR = "/data/raw"
# JSON file holding the shard table and the download queue.
STATE_FILE = "/data/state.json"
# Seconds a shard may stay "claimed" before reclaim_stale() resets it.
WORKER_TIMEOUT = 700
# download_loop() pauses while at least this many shards are "pending".
MAX_BUFFERED = 999999
os.makedirs(RAW_DIR, exist_ok=True)
api = HfApi(token=HF_TOKEN)
# Only send an Authorization header when a token is configured; otherwise the
# literal value "Bearer None" would accompany every download request.
AUTH_HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# ── Sources ───────────────────────────────────────────────────────────────────
# Each entry describes one upstream dataset consumed by discover_all():
#   type "hf_list"  - list the files of `repo`, keep the .parquet files whose
#                     path starts with `prefix`, skip the first `skip` of them
#                     (sorted) and take the next `take`.
#   any other type  - use the explicit `urls`; `fmt` names the file format.
# `text_col` is copied into each queue entry and later handed to
# parquet_to_jsonl() as the column to read.
SOURCES = [
    {
        "name": "fineweb",
        "type": "hf_list",
        "repo": "HuggingFaceFW/fineweb-edu",
        "prefix": "data/CC-MAIN-2025-26",
        "skip": 5,
        "take": 10,
        "text_col": "text",
    },
    {
        "name": "wikipedia",
        "type": "hf_list",
        "repo": "wikimedia/wikipedia",
        "prefix": "20231101.en/train-",
        "skip": 2,
        "take": 18,
        "text_col": "text",
    },
    {
        "name": "openwebmath",
        "type": "hf_list",
        "repo": "open-web-math/open-web-math",
        "prefix": "data/train-",
        "skip": 0,
        "take": 6,
        "text_col": "text",
    },
    {
        "name": "code",
        "type": "url_list",
        "text_col": "text",
        "fmt": "jsonl",
        # The language name is percent-encoded before it is placed in the
        # path: a raw "#" (as in "C#") starts the URL fragment, so everything
        # after it would be dropped from the request path.  "+" is kept
        # literal so the "C++" URLs stay exactly as they were.
        "urls": [
            "https://huggingface.co/buckets/Neon-tech/Dataset-arranger/resolve/by-language/"
            f"{quote(lang, safe='+')}/shard_{str(i).zfill(6)}.jsonl?download=true"
            for lang in ["C", "C++", "Java", "Go", "Rust", "Ruby", "PHP", "SQL", "C#", "Scala", "Lua", "Perl", "CSS"]
            for i in range(2)
        ],
    },
]
# ── Keep-alive ────────────────────────────────────────────────────────────────
def serve():
    """Answer every TCP connection on port 7860 with a fixed HTTP 200 "OK".

    Binds to all interfaces and loops forever.  The request itself is never
    read; each accepted connection receives the same canned response and is
    then closed.  Errors on a single connection are tolerated so that one
    misbehaving client cannot end the loop.
    """
    response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind(("0.0.0.0", 7860))
    listener.listen(5)
    print("βœ“ Listening on port 7860")
    while True:
        try:
            conn, _ = listener.accept()
        except OSError as exc:
            # accept() can fail transiently (e.g. the peer aborted the
            # handshake); back off briefly instead of leaving the loop.
            print(f"Keep-alive accept failed: {exc!r}")
            time.sleep(1)
            continue
        # `with` closes the connection even when sending raises.
        with conn:
            try:
                # sendall() keeps writing until the whole response is out;
                # send() may stop after a partial write.
                conn.sendall(response)
            except OSError:
                # The client went away before the reply was delivered;
                # there is nobody left to answer.
                pass
# ── State ─────────────────────────────────────────────────────────────────────
def load_state():
    """Load the persisted state from STATE_FILE, or start with an empty one.

    Returns:
        A dict with a "shards" mapping (shard name -> info dict) and a
        "queue" list.  Both keys are always present in the returned dict,
        even when the file on disk lacks one of them.

    Raises:
        json.JSONDecodeError: if STATE_FILE exists but is not valid JSON.
        KeyError: if a shard entry has no "status" field.
    """
    if not os.path.exists(STATE_FILE):
        print("Starting fresh")
        return {"shards": {}, "queue": []}

    with open(STATE_FILE) as f:
        state = json.load(f)

    # Fill in missing keys here: a file without "shards" would raise KeyError
    # right away, and one without "queue" would raise later in discover_all(),
    # which appends to state["queue"].
    shards = state.setdefault("shards", {})
    queue = state.setdefault("queue", [])

    # One pass over the shards instead of one pass per status.
    by_status = Counter(info["status"] for info in shards.values())
    done = by_status["done"]
    claimed = by_status["claimed"]
    pending = by_status["pending"]
    print(f"Resuming β€” {done} done / {claimed} claimed / {pending} buffered / {len(queue)} queued")
    return state
def save_state(state):
    """Write *state* to STATE_FILE as JSON indented by two spaces.

    The JSON goes to a sibling file named STATE_FILE + ".tmp" first and is
    then moved over STATE_FILE with os.replace, so the final file is swapped
    in by a single rename rather than rewritten in place.
    """
    scratch_path = f"{STATE_FILE}.tmp"
    with open(scratch_path, "w") as handle:
        handle.write(json.dumps(state, indent=2))
    os.replace(scratch_path, STATE_FILE)
# ── Discover ──────────────────────────────────────────────────────────────────
def discover_all(state):
    """Queue every source URL that is not yet known, then save the state.

    For each entry of SOURCES the candidate URLs are determined (by listing
    the dataset repo for "hf_list" sources, or from the entry's own "urls"
    otherwise).  URLs already present in state["shards"] or state["queue"]
    are skipped; the rest are appended to state["queue"].

    Args:
        state: the state dict; its "queue" list is extended in place (and
            created if absent).

    Raises:
        Whatever api.list_repo_files raises when a repo cannot be listed.
    """
    # setdefault rather than direct indexing: the membership set below already
    # tolerated a missing "queue", so appending to it must tolerate it too.
    shards = state.setdefault("shards", {})
    queue = state.setdefault("queue", [])
    known_urls = {info["url"] for info in shards.values()}
    known_urls.update(entry["url"] for entry in queue)

    for src in SOURCES:
        name = src["name"]
        print(f"\nDiscovering: {name}")
        if src["type"] == "hf_list":
            prefix = src["prefix"]
            matching = sorted(
                f for f in api.list_repo_files(src["repo"], repo_type="dataset")
                if f.startswith(prefix) and f.endswith(".parquet")
            )
            # Window of `take` files starting after the first `skip` ones.
            start = src["skip"]
            selected = matching[start: start + src["take"]]
            base_url = f"https://huggingface.co/datasets/{src['repo']}/resolve/main/"
            urls = [base_url + f for f in selected]
            fmt = "parquet"
        else:
            urls = src["urls"]
            fmt = src.get("fmt", "parquet")

        added = 0
        for url in urls:
            if url in known_urls:
                continue
            queue.append({
                "url": url,
                "source": name,
                "text_col": src["text_col"],
                "fmt": fmt,
            })
            known_urls.add(url)
            added += 1
        print(f" {name}: {len(urls)} files | {added} new added to queue")

    save_state(state)
    print(f"\nTotal queued: {len(state['queue'])} | In state: {len(state['shards'])}")
# ── Reclaim stale ─────────────────────────────────────────────────────────────
def reclaim_stale(state):
    """Put shards whose claim has timed out back into the "pending" status.

    A shard is reset when its status is "claimed", it carries a truthy
    "claimed_at" timestamp, and more than WORKER_TIMEOUT seconds have passed
    since that timestamp.  Resetting clears "worker" and "claimed_at".  The
    state is saved only if at least one shard was reset.
    """
    checked_at = time.time()
    any_reset = False
    for shard_name, shard in state["shards"].items():
        if shard["status"] != "claimed":
            continue
        claim_time = shard.get("claimed_at")
        if not claim_time:
            # Claimed without a timestamp: its age cannot be judged.
            continue
        if checked_at - claim_time > WORKER_TIMEOUT:
            print(f" ⚠ Reclaiming: {shard_name}")
            shard.update(status="pending", worker=None, claimed_at=None)
            any_reset = True
    if any_reset:
        save_state(state)
# ── Parquet → JSONL ─────────────────────────────────────────────────────────────
def parquet_to_jsonl(parquet_path, jsonl_path, text_col):
    """Convert one column of a parquet file into a JSONL file.

    The column *text_col* is read in batches of 1,000 rows, so the whole file
    is never materialised at once.  Every value that is a string with
    non-whitespace content is stripped and written as one line of the form
    {"text": ...}; all other values are skipped.

    Returns:
        The number of lines written to *jsonl_path*.
    """
    reader = pq.ParquetFile(parquet_path)
    doc_count = 0
    with open(jsonl_path, "w", encoding="utf-8") as sink:
        for record_batch in reader.iter_batches(batch_size=1_000, columns=[text_col]):
            values = record_batch.column(text_col).to_pylist()
            lines = [
                json.dumps({"text": value.strip()}, ensure_ascii=False) + "\n"
                for value in values
                if isinstance(value, str) and value.strip()
            ]
            sink.writelines(lines)
            doc_count += len(lines)
            # Drop the batch before loading the next one to keep memory flat.
            del values, lines
            gc.collect()
    return doc_count
# ── Download loop ─────────────────────────────────────────────────────────────
def _refresh_state(state):
    """Replace state["shards"] / state["queue"] with what STATE_FILE holds now.

    Best effort: when the file cannot be read or parsed, *state* is left
    untouched and the problem is reported instead of being ignored silently.
    """
    try:
        with open(STATE_FILE) as f:
            fresh = json.load(f)
        shards = fresh["shards"]
        queue = fresh.get("queue", [])
    except Exception as exc:
        print(f" ! State reload failed, keeping in-memory copy: {exc!r}")
        return
    state["shards"] = shards
    state["queue"] = queue


def _fetch_to_file(url, dest):
    """Stream *url* into the file *dest*; raises on network or HTTP errors."""
    # The context manager closes the streamed response (and thereby releases
    # its connection) even when raise_for_status() or a write fails.
    with requests.get(url, headers=AUTH_HEADERS, timeout=300, stream=True) as resp:
        resp.raise_for_status()
        with open(dest, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8 * 1024 * 1024):
                f.write(chunk)


def download_loop(state):
    """Work through the download queue until every shard is done.

    Each iteration reloads the state from disk, reclaims stale claims and
    then handles the first queue entry: the file is downloaded to a ".tmp"
    path in RAW_DIR, converted to JSONL when its format is "parquet" (or
    simply renamed otherwise), and registered in state["shards"] with status
    "pending".  A failed download or conversion is retried after 30 seconds.

    The loop sleeps while MAX_BUFFERED or more shards are pending, and while
    the queue is empty.  It returns once the queue is empty and every known
    shard has status "done".

    Args:
        state: the shared state dict; updated in place.
    """
    while True:
        _refresh_state(state)
        reclaim_stale(state)

        shards = state["shards"]
        buffered = sum(1 for v in shards.values() if v["status"] == "pending")
        if buffered >= MAX_BUFFERED:
            time.sleep(30)
            continue

        if not state["queue"]:
            done = sum(1 for v in shards.values() if v["status"] == "done")
            total = len(shards)
            if done == total and total > 0:
                print("βœ“ All shards complete!")
                break
            print(" Queue empty β€” sleeping...")
            time.sleep(60)
            continue

        entry = state["queue"][0]
        url = entry["url"]
        source = entry["source"]
        text_col = entry["text_col"]
        fmt = entry.get("fmt", "parquet")

        # The shard name combines the file name with its parent path segment,
        # so equally named files from different folders do not collide.
        path_parts = url.split("?")[0].split("/")
        lang = path_parts[-2]
        base_name = path_parts[-1].replace(".parquet", "").replace(".jsonl", "")
        shard_name = f"{source}__{base_name}_{lang}.jsonl"
        jsonl_path = Path(RAW_DIR) / shard_name
        tmp_path = Path(RAW_DIR) / f"{shard_name}.tmp"

        print(f" Downloading: {source} | {base_name}")
        try:
            _fetch_to_file(url, tmp_path)
        except Exception as e:
            print(f" βœ— Download failed: {e} β€” retrying in 30s")
            tmp_path.unlink(missing_ok=True)
            time.sleep(30)
            continue

        if fmt == "parquet":
            print(f" Converting β†’ jsonl: {shard_name}")
            try:
                n = parquet_to_jsonl(tmp_path, jsonl_path, text_col)
                tmp_path.unlink(missing_ok=True)
                print(f" βœ“ {n:,} docs")
            except Exception as e:
                print(f" βœ— Convert failed: {e}")
                tmp_path.unlink(missing_ok=True)
                jsonl_path.unlink(missing_ok=True)
                time.sleep(30)
                continue
        else:
            tmp_path.rename(jsonl_path)

        # Downloading and converting can take a long time, and this loop
        # already treats STATE_FILE as something that may change underneath
        # it (it re-reads the file every iteration).  Re-read it once more
        # right before committing so that changes written to the file in the
        # meantime are not overwritten by the snapshot taken before the
        # download started.
        _refresh_state(state)
        state["queue"] = [q for q in state["queue"] if q["url"] != url]
        state["shards"][shard_name] = {
            "status": "pending",
            "url": url,
            "source": source,
            "worker": None,
            "claimed_at": None,
            "error": None,
        }
        save_state(state)
        print(f" βœ“ Ready: {shard_name}")
        time.sleep(3)
# ── Monitor ───────────────────────────────────────────────────────────────────
def monitor_loop():
    """Print a progress summary read from STATE_FILE every 120 seconds.

    The summary reports done / total shards (total = known shards plus
    queued entries), the number of claimed and pending shards, the queue
    length, and the count of done shards per source.  Runs forever; a failed
    round is reported and the loop carries on with the next one.
    """
    while True:
        time.sleep(120)
        try:
            with open(STATE_FILE) as f:
                snapshot = json.load(f)
            shards = snapshot["shards"]
            queue = snapshot.get("queue", [])

            by_status = Counter(info["status"] for info in shards.values())
            done = by_status["done"]
            claimed = by_status["claimed"]
            pending = by_status["pending"]
            total = len(shards) + len(queue)
            pct = (done / total * 100) if total else 0

            done_by_source = Counter(
                info.get("source", "?")
                for info in shards.values()
                if info["status"] == "done"
            )

            print(f"[MONITOR] {done}/{total} ({pct:.1f}%) | {claimed} active | {pending} buffered | {len(queue)} queued")
            for src, cnt in sorted(done_by_source.items()):
                print(f" {src}: {cnt} done")
        except Exception as exc:
            # Monitoring is best effort, but say why a round was skipped
            # instead of going silent.
            print(f"[MONITOR] skipped: {exc!r}")
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Bring the keep-alive listener up first so the port answers while the
    # state is loaded and the sources are discovered.
    keepalive = threading.Thread(target=serve, daemon=True)
    keepalive.start()

    state = load_state()
    discover_all(state)

    # Background workers: progress reporting first, then the downloader.
    for worker in (
        threading.Thread(target=monitor_loop, daemon=True),
        threading.Thread(target=download_loop, args=(state,), daemon=True),
    ):
        worker.start()

    # Every thread above is a daemon, so the main thread must stay alive.
    while True:
        time.sleep(60)