import os
import json
import time
import socket
import threading
import gc
import ctypes
import multiprocessing as mp
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
from tokenizers import Tokenizer

# ── Config ───────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
STATE_FILE = "/data/state.json"
RAW_DIR = "/data/raw"
OUT_DIR = "/data/tokenized"
TOK_PATH = "/data/tokenizer.json"
WORKER_ID = socket.gethostname()
POLL_INTERVAL = 15

os.makedirs(OUT_DIR, exist_ok=True)

# ── Keep-alive ────────────────────────────────────────────────────────────────
def serve():
    # Minimal HTTP 200 responder so platform health checks keep this worker alive.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("0.0.0.0", 7860))
    s.listen(5)
    print(f"✓ [{WORKER_ID}] Listening on port 7860")
    while True:
        conn, _ = s.accept()
        conn.send(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
        conn.close()

# ── State ─────────────────────────────────────────────────────────────────────
def load_state():
    with open(STATE_FILE) as f:
        return json.load(f)

def save_state(state):
    # Write to a per-worker temp file, then swap it into place atomically.
    tmp = STATE_FILE + f".tmp.{WORKER_ID}"
    with open(tmp, "w") as f:
        json.dump(state, f, indent=2)
    os.replace(tmp, STATE_FILE)

# ── Claim a pending shard ─────────────────────────────────────────────────────
def claim_shard(state):
    # Take the first pending shard whose raw file has already been downloaded.
    # Note: there is no cross-worker lock; concurrent workers can race on
    # state.json, and the last save_state() wins.
    for name, info in state["shards"].items():
        if info["status"] == "pending":
            raw_path = Path(RAW_DIR) / name
            if raw_path.exists():
                info["status"] = "claimed"
                info["worker"] = WORKER_ID
                info["claimed_at"] = time.time()
                save_state(state)
                return name, raw_path
    return None, None

# ── Tokenizer subprocess ──────────────────────────────────────────────────────
_worker_tokenizer = None

def init_worker(tok_path):
    # Each pool process loads its own Tokenizer instance once, at startup.
    global _worker_tokenizer
    _worker_tokenizer = Tokenizer.from_file(tok_path)

def tokenize_chunk(texts):
    encs = _worker_tokenizer.encode_batch(texts)
    # Drop degenerate encodings shorter than two tokens.
    return [e.ids for e in encs if len(e.ids) >= 2]

# ── Process shard ─────────────────────────────────────────────────────────────
def process_shard(name, raw_path, pool):
    print(f" [{WORKER_ID}] Processing: {name}")
    out_name = name.replace(".parquet", ".bin")
    out_path = Path(OUT_DIR) / out_name
    tmp_path = Path(OUT_DIR) / f"{out_name}.tmp"
    total_tokens = 0

    try:
        pf = pq.ParquetFile(raw_path)
    except Exception as e:
        # Unreadable download: remove it so it can be re-fetched on a later pass.
        raw_path.unlink(missing_ok=True)
        return False, f"read_failed: {e}"

    try:
        with open(tmp_path, "wb") as f:
            for batch in pf.iter_batches(batch_size=5_000, columns=["text"]):
                texts = batch.column("text").to_pylist()
                # Split each batch in half across the two pool processes.
                mid = len(texts) // 2
                try:
                    results = pool.map(tokenize_chunk, [texts[:mid], texts[mid:]])
                except Exception as e:
                    tmp_path.unlink(missing_ok=True)
                    return False, f"tokenize_failed: {e}"
                for ids in results[0] + results[1]:
                    # uint16 assumes every token id fits in 65,535.
                    arr = np.array(ids, dtype=np.uint16)
                    arr.tofile(f)
                    total_tokens += len(ids)
                del texts, results
                gc.collect()
    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        return False, f"write_failed: {e}"

    tmp_path.rename(out_path)  # ← atomic, only visible when complete
    print(f" ✓ [{WORKER_ID}] {out_name} | {total_tokens:,} tokens")
    return True, None

# ── Force full memory flush ───────────────────────────────────────────────────
def flush_memory():
    gc.collect()
    try:
        # Return freed heap pages to the OS (glibc only).
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except Exception:
        pass

# ── Worker loop ───────────────────────────────────────────────────────────────
def worker_loop():
    os.makedirs(OUT_DIR, exist_ok=True)
    print(f"✓ [{WORKER_ID}] Loading tokenizer...")
    tok = Tokenizer.from_file(TOK_PATH)
    print(f"✓ [{WORKER_ID}] Tokenizer ready | vocab: {tok.get_vocab_size():,}")
    del tok
    flush_memory()

    pool = mp.Pool(processes=2, initializer=init_worker, initargs=(TOK_PATH,))
    print(f"✓ [{WORKER_ID}] Worker pool ready")

    try:
        while True:
            if not os.path.exists(STATE_FILE):
                print(f" [{WORKER_ID}] Waiting for state.json...")
                time.sleep(POLL_INTERVAL)
                continue

            try:
                state = load_state()
            except Exception as e:
                print(f" [{WORKER_ID}] State read error: {e}")
                time.sleep(POLL_INTERVAL)
                continue

            total = len(state["shards"]) + len(state.get("queue", []))
            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
            if total > 0 and done == total:
                print(f" [{WORKER_ID}] All done. Sleeping.")
                time.sleep(300)
                continue

            name, raw_path = claim_shard(state)
            if not name:
                print(f" [{WORKER_ID}] Nothing ready — polling in {POLL_INTERVAL}s")
                time.sleep(POLL_INTERVAL)
                continue

            print(f" [{WORKER_ID}] Claimed: {name}")
            success, error = process_shard(name, raw_path, pool)

            # Re-read state to pick up other workers' updates before writing ours.
            try:
                state = load_state()
            except Exception:
                pass

            if success:
                state["shards"][name]["status"] = "done"
                state["shards"][name]["error"] = None
                save_state(state)
                try:
                    raw_path.unlink()
                    print(f" [{WORKER_ID}] Deleted: {raw_path.name}")
                except Exception as e:
                    print(f" [{WORKER_ID}] Delete failed: {e}")
            else:
                state["shards"][name]["status"] = "pending"
                state["shards"][name]["worker"] = None
                state["shards"][name]["claimed_at"] = None
                state["shards"][name]["error"] = error
                save_state(state)
                print(f" [{WORKER_ID}] Failed ({error}) — reset to pending for retry: {name}")

            flush_memory()
            time.sleep(5)
    finally:
        pool.terminate()
        pool.join()

# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    threading.Thread(target=serve, daemon=True).start()
    threading.Thread(target=worker_loop, daemon=True).start()
    while True:
        time.sleep(60)
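
# ── Example state.json (illustrative sketch) ─────────────────────────────────
# The shard names, worker id, and timestamp below are made up; the field names
# and status values ("pending" / "claimed" / "done") are the ones this script
# reads and writes:
#
# {
#   "queue": ["shard_00042.parquet"],
#   "shards": {
#     "shard_00001.parquet": {
#       "status": "done",
#       "worker": "worker-a1b2",
#       "claimed_at": 1717000000.0,
#       "error": null
#     },
#     "shard_00002.parquet": {
#       "status": "pending",
#       "worker": null,
#       "claimed_at": null,
#       "error": "tokenize_failed: ..."
#     }
#   }
# }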