import ctypes
import gc
import itertools
import json
import multiprocessing as mp
import os
import socket
import threading
import time
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
from tokenizers import Tokenizer

# ── Config ───────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")  # NOTE(review): not used in this file
STATE_FILE = "/data/state.json"
RAW_DIR = "/data/raw"
OUT_DIR = "/data/tokenized"
TOK_PATH = "/data/tokenizer.json"
WORKER_ID = socket.gethostname()
POLL_INTERVAL = 15
SCAN_INTERVAL = 600
# Mop-up only resets another host's "claimed" shard once the claim is older
# than this many seconds. Resetting live claims makes two workers tokenize the
# same shard, and one of them then deletes the raw file under the other.
# NOTE(review): tune so it comfortably exceeds the slowest shard.
CLAIM_TIMEOUT = 4 * 3600
# *.tmp files count as orphaned only once untouched (by mtime) for this many
# seconds, so temp files still being written are left alone.
ORPHAN_TMP_AGE = 3600
# Attempts before a shard is marked "failed" instead of "pending".
MAX_RETRIES = 3
# Token ids are written as uint16, so every id must be < 65536.
MAX_VOCAB_SIZE = int(np.iinfo(np.uint16).max) + 1

# Serializes load→modify→save cycles on STATE_FILE between the mop-up thread
# and the worker loop of THIS process. It does not guard against other hosts
# writing STATE_FILE concurrently.
_STATE_LOCK = threading.RLock()

# Shard this process is currently tokenizing (None when idle). Set and cleared
# under _STATE_LOCK by worker_loop; read by mop-up so live work is never reset
# or cleaned up.
_current_shard = None


# ── Keep-alive ────────────────────────────────────────────────────────────────
def serve():
    """Answer every TCP connection on port 7860 with a fixed HTTP 200 "OK".

    Runs forever and is started as a daemon thread. Per-connection errors
    are swallowed so one bad client cannot stop the endpoint.
    """
    response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(("0.0.0.0", 7860))
        listener.listen(5)
        print(f"✓ [{WORKER_ID}] Listening on port 7860")
        while True:
            try:
                conn, _ = listener.accept()
            except OSError as e:
                print(f"  [{WORKER_ID}] accept failed: {e}")
                time.sleep(1)  # avoid a hot loop on persistent errors (e.g. EMFILE)
                continue
            with conn:
                try:
                    conn.settimeout(5)
                    # Drain the request first: closing a socket with unread
                    # input can make the kernel send RST, and the client may
                    # then lose the response.
                    conn.recv(65536)
                    conn.sendall(response)
                except OSError:
                    # Client hung up or timed out; nothing useful to do.
                    pass


# ── State ─────────────────────────────────────────────────────────────────────
def load_state():
    """Read and return the JSON state dict from STATE_FILE."""
    with open(STATE_FILE) as f:
        return json.load(f)


def save_state(state):
    """Atomically replace STATE_FILE with *state* (write temp file, then os.replace)."""
    # Include the thread id: the mop-up thread and worker loop share WORKER_ID,
    # and a shared temp name would let their writes interleave.
    tmp = f"{STATE_FILE}.tmp.{WORKER_ID}.{threading.get_ident()}"
    with open(tmp, "w") as f:
        json.dump(state, f, indent=2)
    os.replace(tmp, STATE_FILE)


# ── Tokenizer subprocess ──────────────────────────────────────────────────────
_worker_tokenizer = None


def init_worker(tok_path):
    """Pool initializer: load the tokenizer once per pool worker process."""
    global _worker_tokenizer
    _worker_tokenizer = Tokenizer.from_file(tok_path)


def tokenize_chunk(texts):
    """Encode *texts* in a pool worker; return id lists of docs with >= 2 tokens."""
    encs = _worker_tokenizer.encode_batch(texts)
    return [e.ids for e in encs if len(e.ids) >= 2]


# ── Process shard ─────────────────────────────────────────────────────────────
def process_shard(name, raw_path, pool):
    """Tokenize one raw parquet shard into OUT_DIR/<name with .bin>.

    Streams the "text" column in batches of 5,000 rows and tokenizes each
    batch as two halves across *pool*. Every kept document's ids are appended
    as uint16 (native byte order, no separator token) to a ".bin.tmp" file,
    which is renamed to the final ".bin" only when all batches succeed.

    Args:
        name: Shard file name, e.g. "xyz.parquet".
        raw_path: Path to the raw parquet file.
        pool: multiprocessing.Pool initialised with init_worker.

    Returns:
        (True, None) on success. Otherwise (False, reason), where reason
        starts with "read_failed", "tokenize_failed", "write_failed" or
        "rename_failed". On "read_failed" the raw file is deleted.
    """
    print(f"  [{WORKER_ID}] Processing: {name}")
    out_name = name.replace(".parquet", ".bin")
    out_path = Path(OUT_DIR) / out_name
    tmp_path = Path(OUT_DIR) / f"{out_name}.tmp"
    total_tokens = 0

    try:
        pf = pq.ParquetFile(raw_path)
    except Exception as e:
        raw_path.unlink(missing_ok=True)
        return False, f"read_failed: {e}"

    try:
        with open(tmp_path, "wb") as f:
            for batch in pf.iter_batches(batch_size=5_000, columns=["text"]):
                # Parquet nulls arrive as None, which the tokenizer rejects.
                texts = [t for t in batch.column("text").to_pylist() if t is not None]
                mid = len(texts) // 2
                try:
                    results = pool.map(tokenize_chunk, [texts[:mid], texts[mid:]])
                except Exception as e:
                    tmp_path.unlink(missing_ok=True)
                    return False, f"tokenize_failed: {e}"
                # One write per batch rather than per document; same bytes on disk.
                arr = np.fromiter(
                    itertools.chain.from_iterable(results[0] + results[1]),
                    dtype=np.uint16,
                )
                arr.tofile(f)
                total_tokens += arr.size
                del texts, results, arr
                gc.collect()
    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        return False, f"write_failed: {e}"

    try:
        tmp_path.rename(out_path)
    except OSError as e:
        # Report instead of raising so the worker loop survives.
        tmp_path.unlink(missing_ok=True)
        return False, f"rename_failed: {e}"

    print(f"  ✓ [{WORKER_ID}] {out_name} | {total_tokens:,} tokens")
    return True, None


# ── Flush memory ──────────────────────────────────────────────────────────────
def flush_memory():
    """Run the GC and ask glibc to return freed heap pages to the OS.

    Best effort: malloc_trim is skipped when libc.so.6 is unavailable.
    """
    gc.collect()
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except Exception:
        pass  # non-glibc platform; trimming is only an optimisation


# ── Mop-up thread — just fixes state, worker_loop does the tokenizing ─────────
def _reset_to_pending(info):
    """Clear claim/error bookkeeping on a shard entry and mark it pending."""
    info["status"] = "pending"
    info["worker"] = None
    info["claimed_at"] = None
    info["error"] = None
    info["retries"] = 0


def _claim_is_stale(name, info, now):
    """Return True if a "claimed" shard entry may be handed back to "pending".

    A claim held by this host is stale unless it is the shard this process
    is tokenizing right now (e.g. it is left over from an earlier run). Any
    other host's claim is stale once older than CLAIM_TIMEOUT, or if it has
    no timestamp.
    """
    if info.get("worker") == WORKER_ID:
        return name != _current_shard
    claimed_at = info.get("claimed_at")
    return claimed_at is None or now - claimed_at > CLAIM_TIMEOUT


def _fix_state(state):
    """Apply mop-up repairs to *state* in place; return True if anything changed."""
    changed = False
    now = time.time()
    for name, info in state["shards"].items():
        status = info["status"]
        if status == "done":
            continue
        raw_path = Path(RAW_DIR) / name
        out_path = Path(OUT_DIR) / name.replace(".parquet", ".bin")

        # already tokenized but not marked done
        if out_path.exists():
            print(f"  [MOPUP] Already tokenized, marking done: {name}")
            info["status"] = "done"
            info["error"] = None
            changed = True
            continue

        raw_exists = raw_path.exists()
        # reset failed / stale-claimed back to pending if raw is on disk
        if raw_exists and (
            status == "failed"
            or (status == "claimed" and _claim_is_stale(name, info, now))
        ):
            print(f"  [MOPUP] Resetting {status} → pending: {name}")
            _reset_to_pending(info)
            changed = True
        # re-queue if failed and raw is gone
        elif status == "failed" and not raw_exists:
            hf_path = info.get("hf_path")
            queue = state.get("queue", [])
            if hf_path and hf_path not in queue:
                print(f"  [MOPUP] Re-queuing for download: {name}")
                queue.append(hf_path)
                state["queue"] = queue
                _reset_to_pending(info)
                changed = True
    return changed


def _remove_orphaned_tmps():
    """Delete *.tmp files in OUT_DIR and RAW_DIR not modified for ORPHAN_TMP_AGE s.

    Never touches this process's in-progress shard output. Files still being
    written keep a fresh mtime, so they are not considered orphaned.
    """
    current = _current_shard
    active_tmp = f"{current.replace('.parquet', '.bin')}.tmp" if current else None
    now = time.time()
    candidates = itertools.chain(
        Path(OUT_DIR).glob("*.tmp"), Path(RAW_DIR).glob("*.tmp")
    )
    for tmp_file in candidates:
        if tmp_file.name == active_tmp:
            continue
        try:
            age = now - tmp_file.stat().st_mtime
        except FileNotFoundError:
            continue  # its owner finished or removed it meanwhile
        if age < ORPHAN_TMP_AGE:
            continue
        print(f"  [MOPUP] Removing orphaned tmp: {tmp_file.name}")
        tmp_file.unlink(missing_ok=True)


def mopup_thread():
    """Every SCAN_INTERVAL seconds, repair state.json and remove stale temp files.

    - Marks shards whose .bin exists as done.
    - Returns failed shards and stale claims to pending when the raw file is
      present.
    - Re-queues the hf_path of failed shards whose raw file is gone.
    - Deletes orphaned *.tmp files.

    It never tokenizes; worker_loop picks up whatever was reset.
    """
    while True:
        time.sleep(SCAN_INTERVAL)
        if not os.path.exists(STATE_FILE):
            continue
        try:
            with _STATE_LOCK:
                try:
                    state = load_state()
                except Exception:
                    continue  # unreadable right now; try again next scan
                changed = _fix_state(state)
                if changed:
                    save_state(state)
            _remove_orphaned_tmps()
        except Exception as e:
            # A malformed state file or transient FS error must not kill this
            # daemon thread for the rest of the process lifetime.
            print(f"  [MOPUP] Scan failed: {e}")
            continue
        if changed:
            print("  [MOPUP] State updated — worker_loop will pick up resets")
        else:
            print("  [MOPUP] Nothing to fix")


# ── Worker loop — continuous, same as regular tokenizer ──────────────────────
def _all_done(state):
    """Return True when every shard is done and nothing is left in "queue"."""
    total = len(state["shards"]) + len(state.get("queue", []))
    done = sum(1 for v in state["shards"].values() if v["status"] == "done")
    return total > 0 and done == total


def _claim_pending_shard(state):
    """Claim and save the first pending shard whose raw file is on disk.

    Mutates *state*. Returns the shard name, or None if nothing is ready.
    The caller must hold _STATE_LOCK.
    """
    for name, info in state["shards"].items():
        if info["status"] == "pending" and (Path(RAW_DIR) / name).exists():
            info["status"] = "claimed"
            info["worker"] = WORKER_ID
            info["claimed_at"] = time.time()
            save_state(state)
            return name
    return None


def _record_result(state, name, success, error):
    """Write a shard's outcome into *state* and save it.

    On success the shard is marked done. On failure its retry count is
    bumped and it becomes "pending" again, or "failed" once MAX_RETRIES is
    reached. The caller must hold _STATE_LOCK.
    """
    info = state["shards"][name]
    if success:
        info["status"] = "done"
        info["error"] = None
        save_state(state)
        return
    # NOTE(review): after "read_failed" the raw file has been deleted. If the
    # shard stays "pending", nothing in this file re-queues it (mop-up only
    # re-queues "failed"). Confirm a downloader handles that case.
    retries = info.get("retries", 0) + 1
    info["retries"] = retries
    info["error"] = error
    info["worker"] = None
    info["claimed_at"] = None
    info["status"] = "failed" if retries >= MAX_RETRIES else "pending"
    save_state(state)
    print(f"  [{WORKER_ID}] Failed ({error}) retry {retries}/{MAX_RETRIES}: {name}")


def worker_loop(pool):
    """Claim and tokenize ready shards forever.

    Each iteration:
    - Reads state.json and claims the first "pending" shard whose raw
      parquet is on disk.
    - Tokenizes it with process_shard().
    - Records the outcome via _record_result().
    - On success, deletes the raw file.
    """
    global _current_shard
    while True:
        if not os.path.exists(STATE_FILE):
            print(f"  [{WORKER_ID}] Waiting for state.json...")
            time.sleep(POLL_INTERVAL)
            continue

        try:
            with _STATE_LOCK:
                state = load_state()
                finished = _all_done(state)
                claimed_name = None if finished else _claim_pending_shard(state)
                # Published under the lock together with the saved claim, so
                # mop-up never sees our own fresh claim as stale.
                _current_shard = claimed_name
        except Exception as e:
            print(f"  [{WORKER_ID}] State read/claim error: {e}")
            time.sleep(POLL_INTERVAL)
            continue

        if finished:
            print(f"  [{WORKER_ID}] All done. Sleeping.")
            time.sleep(300)
            continue

        if not claimed_name:
            print(f"  [{WORKER_ID}] Nothing ready — polling in {POLL_INTERVAL}s")
            time.sleep(POLL_INTERVAL)
            continue

        claimed_path = Path(RAW_DIR) / claimed_name
        print(f"  [{WORKER_ID}] Claimed: {claimed_name}")
        success, error = process_shard(claimed_name, claimed_path, pool)

        with _STATE_LOCK:
            try:
                state = load_state()
            except Exception as e:
                # Fall back to the copy we claimed with. This may overwrite
                # updates other hosts made in the meantime.
                print(f"  [{WORKER_ID}] State reload failed, using cached state: {e}")
            _record_result(state, claimed_name, success, error)
            _current_shard = None

        if success:
            try:
                claimed_path.unlink()
                print(f"  [{WORKER_ID}] Deleted: {claimed_path.name}")
            except Exception as e:
                print(f"  [{WORKER_ID}] Delete failed: {e}")

        flush_memory()
        time.sleep(5)


# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    os.makedirs(OUT_DIR, exist_ok=True)

    print(f"✓ [{WORKER_ID}] Loading tokenizer...")
    tok = Tokenizer.from_file(TOK_PATH)
    vocab_size = tok.get_vocab_size()
    print(f"✓ [{WORKER_ID}] Tokenizer ready | vocab: {vocab_size:,}")
    del tok
    if vocab_size > MAX_VOCAB_SIZE:
        # Ids are stored as uint16; a larger vocab would corrupt every shard.
        raise SystemExit(
            f"✗ [{WORKER_ID}] vocab {vocab_size:,} does not fit in uint16 "
            f"(max {MAX_VOCAB_SIZE:,})"
        )
    flush_memory()

    # Create the pool BEFORE starting threads, so the fork happens while this
    # process is still single-threaded.
    pool = mp.Pool(processes=2, initializer=init_worker, initargs=(TOK_PATH,))
    print(f"✓ [{WORKER_ID}] Worker pool ready")

    threading.Thread(target=serve, daemon=True).start()
    threading.Thread(target=mopup_thread, daemon=True).start()

    try:
        worker_loop(pool)
    finally:
        pool.terminate()
        pool.join()