import os
import json
import time
import socket
import threading
import gc
import ctypes
import multiprocessing as mp
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
from tokenizers import Tokenizer


HF_TOKEN = os.environ.get("HF_TOKEN")
STATE_FILE = "/data/state.json"
RAW_DIR = "/data/raw"
OUT_DIR = "/data/tokenized"
TOK_PATH = "/data/tokenizer.json"
WORKER_ID = socket.gethostname()
POLL_INTERVAL = 15

os.makedirs(OUT_DIR, exist_ok=True)


def serve():
    """Tiny HTTP responder so external health checks get a 200."""
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("0.0.0.0", 7860))
    s.listen(5)
    print(f"[{WORKER_ID}] Listening on port 7860")
    while True:
        conn, _ = s.accept()
        # sendall: plain send() may transmit fewer bytes than requested
        conn.sendall(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
        conn.close()
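

# The endpoint can be probed with any HTTP client, e.g.:
#   curl -s http://<worker-host>:7860/   ->   OK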


def load_state():
    with open(STATE_FILE) as f:
        return json.load(f)
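

# Shape of state.json, inferred from the fields this script reads and writes
# (the shard name is illustrative):
# {
#   "shards": {
#     "shard_00000.parquet": {
#       "status": "pending" | "claimed" | "done",
#       "worker": null,
#       "claimed_at": null,
#       "error": null
#     }
#   },
#   "queue": []   # optional: shards not yet registered/downloaded
# }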


def save_state(state):
    # write-then-rename keeps each individual write atomic on POSIX
    tmp = STATE_FILE + f".tmp.{WORKER_ID}"
    with open(tmp, "w") as f:
        json.dump(state, f, indent=2)
    os.replace(tmp, STATE_FILE)
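

# save_state() makes each write atomic, but two workers can still interleave a
# load -> modify -> save cycle and lose one side's update. A minimal sketch of
# advisory locking around the whole cycle, assuming all workers share a POSIX
# filesystem where fcntl.flock is reliable (it often is not on NFS); the
# lock-file path and helper name are illustrative, not part of the original:
import fcntl
from contextlib import contextmanager


@contextmanager
def locked_state(lock_path=STATE_FILE + ".lock"):
    """Hold an exclusive advisory lock for a full load/modify/save cycle."""
    with open(lock_path, "w") as lock:
        fcntl.flock(lock, fcntl.LOCK_EX)
        try:
            yield
        finally:
            fcntl.flock(lock, fcntl.LOCK_UN)
# usage:
#   with locked_state():
#       state = load_state()
#       ...mutate state...
#       save_state(state)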


def claim_shard(state):
    # first-come claim: take the first pending shard whose raw file has landed
    for name, info in state["shards"].items():
        if info["status"] == "pending":
            raw_path = Path(RAW_DIR) / name
            if raw_path.exists():
                info["status"] = "claimed"
                info["worker"] = WORKER_ID
                info["claimed_at"] = time.time()
                save_state(state)
                return name, raw_path
    return None, None
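

# claimed_at is written but never read back, so a shard claimed by a worker
# that later crashed stays "claimed" forever. A hedged sketch of a reclaim
# pass; CLAIM_TTL and reclaim_stale() are illustrative, not in the original:
CLAIM_TTL = 3600  # seconds before an unfinished claim is considered stale


def reclaim_stale(state):
    """Return stale claims to 'pending' so another worker can pick them up."""
    now = time.time()
    changed = False
    for info in state["shards"].values():
        if (info["status"] == "claimed"
                and info.get("claimed_at")
                and now - info["claimed_at"] > CLAIM_TTL):
            info["status"] = "pending"
            info["worker"] = None
            info["claimed_at"] = None
            changed = True
    if changed:
        save_state(state)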


_worker_tokenizer = None


def init_worker(tok_path):
    # each pool process loads its own Tokenizer once, at pool start-up
    global _worker_tokenizer
    _worker_tokenizer = Tokenizer.from_file(tok_path)


def tokenize_chunk(texts):
    encs = _worker_tokenizer.encode_batch(texts)
    # drop degenerate documents that encode to fewer than 2 tokens
    return [e.ids for e in encs if len(e.ids) >= 2]
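

# Note: tokenizers' encode_batch() already parallelizes internally in the Rust
# core, so the two-process pool may buy memory isolation more than extra
# throughput; setting TOKENIZERS_PARALLELISM=false in the children silences
# the fork warning if it appears.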


def process_shard(name, raw_path, pool):
    print(f"[{WORKER_ID}] Processing: {name}")

    out_name = name.replace(".parquet", ".bin")
    out_path = Path(OUT_DIR) / out_name
    tmp_path = Path(OUT_DIR) / f"{out_name}.tmp"
    total_tokens = 0

    try:
        pf = pq.ParquetFile(raw_path)
    except Exception as e:
        # corrupt download: delete it so the producer can re-fetch the shard
        raw_path.unlink(missing_ok=True)
        return False, f"read_failed: {e}"

    try:
        with open(tmp_path, "wb") as f:
            for batch in pf.iter_batches(batch_size=5_000, columns=["text"]):
                texts = batch.column("text").to_pylist()
                mid = len(texts) // 2

                # split each batch in half across the two pool processes
                try:
                    results = pool.map(tokenize_chunk, [texts[:mid], texts[mid:]])
                except Exception as e:
                    tmp_path.unlink(missing_ok=True)
                    return False, f"tokenize_failed: {e}"

                # uint16 assumes the vocab fits in 65,536 ids; use uint32 otherwise
                for ids in results[0] + results[1]:
                    arr = np.array(ids, dtype=np.uint16)
                    arr.tofile(f)
                    total_tokens += len(ids)

                del texts, results
                gc.collect()

    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        return False, f"write_failed: {e}"

    # rename within one filesystem is atomic: readers never see a partial .bin
    tmp_path.rename(out_path)
    print(f"[{WORKER_ID}] {out_name} | {total_tokens:,} tokens")
    return True, None
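

# The .bin output is a flat stream of uint16 token ids with no header and no
# document boundaries. A minimal sketch of reading a shard back (the path in
# the usage line is illustrative):
def load_tokens(bin_path):
    """Memory-map one tokenized shard as a flat uint16 array."""
    return np.memmap(bin_path, dtype=np.uint16, mode="r")
# e.g. tokens = load_tokens("/data/tokenized/shard_00000.bin")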


def flush_memory():
    gc.collect()
    # glibc-only: ask the allocator to return freed arenas to the OS;
    # on musl or macOS the CDLL lookup fails and is silently ignored
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except Exception:
        pass


def worker_loop():
    os.makedirs(OUT_DIR, exist_ok=True)
    print(f"[{WORKER_ID}] Loading tokenizer...")
    # load once in the parent only to validate the file and report vocab size
    tok = Tokenizer.from_file(TOK_PATH)
    print(f"[{WORKER_ID}] Tokenizer ready | vocab: {tok.get_vocab_size():,}")
    del tok
    flush_memory()

    pool = mp.Pool(processes=2, initializer=init_worker, initargs=(TOK_PATH,))
    print(f"[{WORKER_ID}] Worker pool ready")

    try:
        while True:
            if not os.path.exists(STATE_FILE):
                print(f"[{WORKER_ID}] Waiting for state.json...")
                time.sleep(POLL_INTERVAL)
                continue

            try:
                state = load_state()
            except Exception as e:
                print(f"[{WORKER_ID}] State read error: {e}")
                time.sleep(POLL_INTERVAL)
                continue

            # count registered shards plus any still queued for download
            total = len(state["shards"]) + len(state.get("queue", []))
            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
            if total > 0 and done == total:
                print(f"[{WORKER_ID}] All done. Sleeping.")
                time.sleep(300)
                continue

            name, raw_path = claim_shard(state)

            if not name:
                print(f"[{WORKER_ID}] Nothing ready - polling in {POLL_INTERVAL}s")
                time.sleep(POLL_INTERVAL)
                continue

            print(f"[{WORKER_ID}] Claimed: {name}")
            success, error = process_shard(name, raw_path, pool)

            # re-read the latest state so this update does not clobber other
            # workers' progress; fall back to the in-memory copy if that fails
            try:
                state = load_state()
            except Exception:
                pass

            if success:
                state["shards"][name]["status"] = "done"
                state["shards"][name]["error"] = None
                save_state(state)
                try:
                    raw_path.unlink()
                    print(f"[{WORKER_ID}] Deleted: {raw_path.name}")
                except Exception as e:
                    print(f"[{WORKER_ID}] Delete failed: {e}")
            else:
                # release the claim so any worker can retry the shard
                state["shards"][name]["status"] = "pending"
                state["shards"][name]["worker"] = None
                state["shards"][name]["claimed_at"] = None
                state["shards"][name]["error"] = error
                save_state(state)
                print(f"[{WORKER_ID}] Failed ({error}) - released for retry: {name}")

            flush_memory()
            time.sleep(5)

    finally:
        pool.terminate()
        pool.join()


if __name__ == "__main__":
    # health endpoint and worker loop run as daemon threads; the main thread
    # just keeps the process alive
    threading.Thread(target=serve, daemon=True).start()
    threading.Thread(target=worker_loop, daemon=True).start()
    while True:
        time.sleep(60)