| import os |
| import json |
| import time |
| import socket |
| import threading |
| import gc |
| import ctypes |
| import multiprocessing as mp |
| from pathlib import Path |
| import numpy as np |
| import pyarrow.parquet as pq |
| from tokenizers import Tokenizer |
|
|
| |
# --- Credentials -----------------------------------------------------------
# Hugging Face token from the environment (None when unset). Not referenced by
# any code visible in this file.
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Shared filesystem layout ------------------------------------------------
STATE_FILE: str = "/data/state.json"  # shared JSON job/queue state
RAW_DIR: str = "/data/raw"  # downloaded *.parquet shards
OUT_DIR: str = "/data/tokenized"  # tokenized *.bin outputs
TOK_PATH: str = "/data/tokenizer.json"  # tokenizers JSON file

# --- Identity & timing -------------------------------------------------------
WORKER_ID: str = socket.gethostname()  # tags claims, logs and temp files
POLL_INTERVAL: int = 15  # seconds between polls when idle
SCAN_INTERVAL: int = 600  # seconds between mopup passes
|
|
| |
def serve(host="0.0.0.0", port=7860):
    """Run a minimal HTTP health-check responder forever.

    Every accepted connection gets a fixed ``200 OK`` response with body
    ``OK`` and is then closed, whatever the client sent.

    Args:
        host: Interface to bind (default: all IPv4 interfaces).
        port: TCP port to listen on (default 7860).

    Never returns; intended to run on a daemon thread. Errors on a single
    client connection are ignored so the endpoint stays up.
    """
    response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind((host, port))
        srv.listen(5)
        print(f"β [{WORKER_ID}] Listening on port {port}")
        while True:
            conn, _ = srv.accept()
            with conn:  # always closes the client socket
                try:
                    # Bounded wait so a silent client can't block the loop.
                    conn.settimeout(1.0)
                    try:
                        # Consume the request first: closing a socket with
                        # unread data makes the kernel send RST, which can
                        # discard our response before the client reads it.
                        conn.recv(4096)
                    except socket.timeout:
                        pass  # client sent nothing; reply anyway
                    # sendall, unlike send, retries until every byte is out.
                    conn.sendall(response)
                except OSError:
                    # Client reset / broken pipe: drop this connection only.
                    pass
|
|
| |
def load_state():
    """Parse and return the JSON document stored at ``STATE_FILE``.

    File and decode errors (``OSError``, ``json.JSONDecodeError``)
    propagate to the caller unchanged.
    """
    return json.loads(Path(STATE_FILE).read_text())
|
|
def save_state(state):
    """Atomically persist ``state`` as indented JSON to ``STATE_FILE``.

    The JSON is written to a temporary file next to ``STATE_FILE``. The file
    is flushed and fsync'd, then moved over the target with ``os.replace``,
    so readers never see a partially written document. The temp name
    includes both ``WORKER_ID`` and the calling thread id. Both the worker
    loop and the mopup thread call this in the same process, and a shared
    per-host temp name would let them clobber each other's temp file. If
    serialisation or writing fails, the temp file is removed and the
    exception is re-raised.

    NOTE(review): callers do an unlocked read-modify-write of STATE_FILE,
    so concurrent writers can still overwrite each other's updates (last
    writer wins); only the individual write is atomic.
    """
    tmp = STATE_FILE + f".tmp.{WORKER_ID}.{threading.get_ident()}"
    try:
        with open(tmp, "w") as f:
            json.dump(state, f, indent=2)
            f.flush()
            os.fsync(f.fileno())  # data on disk before the rename publishes it
        os.replace(tmp, STATE_FILE)
    except BaseException:
        # Don't leave a half-written temp file behind; surface the error.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
|
|
| |
# Tokenizer handle for the current process; assigned only by ``init_worker``
# (run as the multiprocessing pool initializer), None until then.
_worker_tokenizer = None


def init_worker(tok_path):
    """Pool initializer: load the tokenizer file at ``tok_path``.

    The loaded tokenizer is kept in the module-level ``_worker_tokenizer``,
    so every task executed by this process reuses it instead of reloading.
    """
    global _worker_tokenizer
    loaded = Tokenizer.from_file(tok_path)
    _worker_tokenizer = loaded
|
|
def tokenize_chunk(texts):
    """Tokenize a list of documents with this process's tokenizer.

    Args:
        texts: Documents to encode. ``None`` entries (null values in the
            parquet ``text`` column) are skipped. ``encode_batch`` only
            accepts strings, so passing them through would fail the whole
            shard.

    Returns:
        One token-id list per document that encoded to at least two tokens;
        shorter encodings are dropped.

    Requires ``init_worker`` to have run in this process (otherwise
    ``_worker_tokenizer`` is None).
    """
    docs = [t for t in texts if t is not None]
    if not docs:
        return []
    encodings = _worker_tokenizer.encode_batch(docs)
    return [enc.ids for enc in encodings if len(enc.ids) >= 2]
|
|
| |
def process_shard(name, raw_path, pool):
    """Tokenize one parquet shard into a flat binary file of uint16 token ids.

    Reads the ``text`` column of ``raw_path`` in batches of 5,000 rows and
    tokenizes each batch in two halves on ``pool`` via ``tokenize_chunk``.
    The token ids of all kept documents are appended back-to-back, with no
    separators, as uint16 to ``OUT_DIR/<name .parquet -> .bin>``. Output
    goes to a ``.tmp`` sibling that is fsync'd and renamed into place only
    after the whole shard succeeded, so a ``.bin`` file is never partial.

    Args:
        name: Shard file name, e.g. ``"xyz.parquet"``.
        raw_path: ``Path`` of the downloaded parquet file.
        pool: Multiprocessing pool whose workers ran ``init_worker``.

    Returns:
        ``(True, None)`` on success. Otherwise ``(False, reason)``, where
        ``reason`` is prefixed ``read_failed``, ``tokenize_failed``,
        ``token_overflow`` or ``write_failed``. On ``read_failed`` the raw
        file is deleted. On any failure the ``.tmp`` output is removed.
    """
    from itertools import chain

    print(f" [{WORKER_ID}] Processing: {name}")

    out_name = name.replace(".parquet", ".bin")
    out_path = Path(OUT_DIR) / out_name
    tmp_path = Path(OUT_DIR) / f"{out_name}.tmp"
    uint16_max = int(np.iinfo(np.uint16).max)
    total_tokens = 0

    try:
        pf = pq.ParquetFile(raw_path)
    except Exception as e:
        # Unreadable download: discard it so it can be fetched again.
        raw_path.unlink(missing_ok=True)
        return False, f"read_failed: {e}"

    error = None
    try:
        with open(tmp_path, "wb") as f:
            for batch in pf.iter_batches(batch_size=5_000, columns=["text"]):
                texts = batch.column("text").to_pylist()
                mid = len(texts) // 2

                try:
                    results = pool.map(tokenize_chunk, [texts[:mid], texts[mid:]])
                except Exception as e:
                    error = f"tokenize_failed: {e}"
                    break

                docs = results[0] + results[1]
                n_tokens = sum(map(len, docs))
                # Flatten the whole batch into one array (one write per batch
                # instead of one per document); int64 first so ids that do
                # not fit in uint16 are detected instead of silently wrapped.
                flat = np.fromiter(chain.from_iterable(docs), dtype=np.int64, count=n_tokens)
                if flat.size and int(flat.max()) > uint16_max:
                    error = f"token_overflow: id {int(flat.max())} exceeds {uint16_max}"
                    break
                flat.astype(np.uint16).tofile(f)
                total_tokens += int(flat.size)

                # Release batch memory eagerly between batches.
                del texts, results, docs, flat
                gc.collect()

            if error is None:
                f.flush()
                os.fsync(f.fileno())  # contents durable before the rename

        if error is None:
            # Inside the try: a failed rename must be reported, not raised
            # into the caller's loop.
            tmp_path.replace(out_path)
    except Exception as e:
        error = f"write_failed: {e}"

    if error is not None:
        tmp_path.unlink(missing_ok=True)
        return False, error

    print(f" β [{WORKER_ID}] {out_name} | {total_tokens:,} tokens")
    return True, None
|
|
| |
def flush_memory():
    """Best-effort return of freed memory to the operating system.

    First forces a full garbage collection, then asks glibc to trim the heap
    via ``malloc_trim(0)``. If ``libc.so.6`` or ``malloc_trim`` is not
    available, the second step is skipped silently.
    """
    gc.collect()
    try:
        glibc = ctypes.CDLL("libc.so.6")
        glibc.malloc_trim(0)
    except Exception:  # non-glibc platform: nothing more to release
        return
|
|
| |
def _mopup_reset_to_pending(info):
    """Clear claim/error bookkeeping on a shard entry and mark it ``pending``."""
    info["status"] = "pending"
    info["worker"] = None
    info["claimed_at"] = None
    info["error"] = None
    info["retries"] = 0


def _mopup_claim_is_stale(info, stale_after, now):
    """Return True if a claim has no usable timestamp or is older than ``stale_after`` s."""
    claimed_at = info.get("claimed_at")
    if not isinstance(claimed_at, (int, float)):
        return True
    return now - claimed_at > stale_after


def _mopup_reconcile_shards(state, stale_after, now):
    """Apply mopup fixes to ``state["shards"]`` in place; return True if anything changed."""
    changed = False
    for name, info in state["shards"].items():
        status = info["status"]
        if status == "done":
            continue

        raw_path = Path(RAW_DIR) / name
        out_path = Path(OUT_DIR) / name.replace(".parquet", ".bin")

        # Output already published (renamed from .tmp): the work is done.
        if out_path.exists():
            print(f" [MOPUP] Already tokenized, marking done: {name}")
            info["status"] = "done"
            info["error"] = None
            changed = True
            continue

        raw_exists = raw_path.exists()

        # Retry failed shards, and claimed shards whose claim looks abandoned.
        # A live claim also has its raw file present, so the age check is what
        # keeps us from resetting shards that are still being processed.
        resettable = status == "failed" or (
            status == "claimed" and _mopup_claim_is_stale(info, stale_after, now)
        )
        if raw_exists and resettable:
            print(f" [MOPUP] Resetting {status} β pending: {name}")
            _mopup_reset_to_pending(info)
            changed = True

        # Failed and the raw file is gone: queue it for download again.
        if status == "failed" and not raw_exists:
            hf_path = info.get("hf_path")
            queue = state.get("queue", [])
            if hf_path and hf_path not in queue:
                print(f" [MOPUP] Re-queuing for download: {name}")
                queue.append(hf_path)
                state["queue"] = queue
                _mopup_reset_to_pending(info)
                changed = True
    return changed


def _mopup_remove_stale_tmp_files(stale_after, now):
    """Delete ``*.tmp`` files in OUT_DIR/RAW_DIR not modified for ``stale_after`` s."""
    candidates = list(Path(OUT_DIR).glob("*.tmp")) + list(Path(RAW_DIR).glob("*.tmp"))
    for tmp_file in candidates:
        try:
            age = now - tmp_file.stat().st_mtime
        except FileNotFoundError:
            continue  # renamed or removed by its writer in the meantime
        if age <= stale_after:
            continue  # recently written: probably still in progress
        print(f" [MOPUP] Removing orphaned tmp: {tmp_file.name}")
        tmp_file.unlink(missing_ok=True)


def _mopup_pass(stale_after):
    """Run one reconciliation pass over STATE_FILE and the temp files."""
    if not os.path.exists(STATE_FILE):
        return
    try:
        state = load_state()
    except Exception:
        return

    now = time.time()
    changed = _mopup_reconcile_shards(state, stale_after, now)
    _mopup_remove_stale_tmp_files(stale_after, now)

    if changed:
        save_state(state)
        print(f" [MOPUP] State updated β worker_loop will pick up resets")
    else:
        print(f" [MOPUP] Nothing to fix")


def mopup_thread(stale_after=3600):
    """Periodically repair the shared state and clean up orphaned temp files.

    Every ``SCAN_INTERVAL`` seconds, for every shard not yet ``done``:

    * if its ``.bin`` output exists, mark it ``done``;
    * if it is ``failed``, or ``claimed`` for more than ``stale_after``
      seconds, and its raw parquet exists, reset it to ``pending``;
    * if it is ``failed`` and its raw file is gone, append its ``hf_path``
      to ``state["queue"]`` and reset it to ``pending``.

    It also deletes ``*.tmp`` files in ``OUT_DIR`` and ``RAW_DIR`` that have
    not been modified for ``stale_after`` seconds. Fresh claims and fresh
    temp files are left alone. While a worker processes a shard, the raw
    file exists and its ``.tmp`` output is being written. Resetting the
    claim would cause duplicate processing, and deleting the temp file would
    make the worker's final rename fail.

    Args:
        stale_after: Age in seconds after which a claim or temp file is
            considered abandoned. NOTE(review): 3600 is an assumed upper
            bound on per-shard processing time; tune to real shard sizes.

    Runs forever (daemon thread). An error in one pass is printed, and the
    next pass runs as scheduled instead of the thread dying silently.
    """
    while True:
        time.sleep(SCAN_INTERVAL)
        try:
            _mopup_pass(stale_after)
        except Exception as e:
            print(f" [MOPUP] Pass failed: {e}")
|
|
| |
def _worker_claim_next_shard(state):
    """Claim the first ``pending`` shard whose raw file exists.

    Marks it ``claimed`` by this worker, persists ``state`` and returns
    ``(name, raw_path)``; returns ``(None, None)`` when nothing is claimable.
    """
    for name, info in state["shards"].items():
        if info["status"] != "pending":
            continue
        raw_path = Path(RAW_DIR) / name
        if not raw_path.exists():
            continue  # not downloaded yet
        info["status"] = "claimed"
        info["worker"] = WORKER_ID
        info["claimed_at"] = time.time()
        save_state(state)
        return name, raw_path
    return None, None


def _worker_record_result(state, name, raw_path, success, error):
    """Write the outcome of processing ``name`` into ``state`` and persist it.

    On success the shard becomes ``done`` and its raw file is deleted. On
    failure the retry counter is incremented. The shard becomes ``failed``
    after 3 attempts, or immediately if its raw file no longer exists;
    otherwise it goes back to ``pending``.
    """
    # The entry may be missing if the state was rewritten meanwhile.
    entry = state["shards"].setdefault(name, {})
    if success:
        entry["status"] = "done"
        entry["error"] = None
        save_state(state)
        try:
            raw_path.unlink()
            print(f" [{WORKER_ID}] Deleted: {raw_path.name}")
        except Exception as e:
            print(f" [{WORKER_ID}] Delete failed: {e}")
        return

    retries = entry.get("retries", 0) + 1
    entry["retries"] = retries
    entry["error"] = error
    entry["worker"] = None
    entry["claimed_at"] = None
    # A pending shard without a raw file is never claimed again, so it would
    # be stuck; "failed" lets the mopup re-queue it for download.
    give_up = retries >= 3 or not raw_path.exists()
    entry["status"] = "failed" if give_up else "pending"
    save_state(state)
    print(f" [{WORKER_ID}] Failed ({error}) retry {retries}/3: {name}")


def worker_loop(pool):
    """Main loop: claim pending shards from the shared state and tokenize them.

    Forever:

    * wait for ``STATE_FILE`` to exist;
    * sleep 300 s while every shard is ``done`` and the queue is empty;
    * otherwise claim the first ``pending`` shard with a raw file present,
      run ``process_shard`` on ``pool`` and record the outcome.

    Unexpected exceptions from ``process_shard`` count as a failed attempt.
    Errors while recording the outcome are printed. Neither kind stops the
    loop.

    NOTE(review): claiming is an unlocked read-modify-write of STATE_FILE,
    so two workers sharing it can still race for the same shard.
    """
    while True:
        if not os.path.exists(STATE_FILE):
            print(f" [{WORKER_ID}] Waiting for state.json...")
            time.sleep(POLL_INTERVAL)
            continue

        try:
            state = load_state()
        except Exception as e:
            print(f" [{WORKER_ID}] State read error: {e}")
            time.sleep(POLL_INTERVAL)
            continue

        total = len(state["shards"]) + len(state.get("queue", []))
        done = sum(1 for v in state["shards"].values() if v["status"] == "done")
        if total > 0 and done == total:
            print(f" [{WORKER_ID}] All done. Sleeping.")
            time.sleep(300)
            continue

        claimed_name, claimed_path = _worker_claim_next_shard(state)
        if not claimed_name:
            print(f" [{WORKER_ID}] Nothing ready β polling in {POLL_INTERVAL}s")
            time.sleep(POLL_INTERVAL)
            continue

        print(f" [{WORKER_ID}] Claimed: {claimed_name}")
        try:
            success, error = process_shard(claimed_name, claimed_path, pool)
        except Exception as e:
            success, error = False, f"unexpected: {e}"

        # Re-read so we don't overwrite updates made while we were busy; fall
        # back to the copy we claimed from if the read fails.
        try:
            state = load_state()
        except Exception:
            pass

        try:
            _worker_record_result(state, claimed_name, claimed_path, success, error)
        except Exception as e:
            print(f" [{WORKER_ID}] Could not record result for {claimed_name}: {e}")

        flush_memory()
        time.sleep(5)
|
|
| |
if __name__ == "__main__":
    os.makedirs(OUT_DIR, exist_ok=True)

    # Load the tokenizer once in the parent, presumably as an early sanity
    # check of TOK_PATH, and report its vocab size. Pool workers load their
    # own copy through init_worker, so this one is dropped right away.
    print(f"β [{WORKER_ID}] Loading tokenizer...")
    probe = Tokenizer.from_file(TOK_PATH)
    print(f"β [{WORKER_ID}] Tokenizer ready | vocab: {probe.get_vocab_size():,}")
    del probe
    flush_memory()

    pool = mp.Pool(processes=2, initializer=init_worker, initargs=(TOK_PATH,))
    print(f"β [{WORKER_ID}] Worker pool ready")

    # Health-check server and periodic mopup run beside the main loop.
    for background in (serve, mopup_thread):
        threading.Thread(target=background, daemon=True).start()

    try:
        worker_loop(pool)
    finally:
        # Always tear the pool down, even if the main loop raises.
        pool.terminate()
        pool.join()