""" tokenize_dataset.py — Parallel tokenization pipeline Architecture: Main thread : stream HF dataset → filter → normalize → batch texts Worker pool : N_WORKERS processes, each with own loaded tokenizer, tokenize batches concurrently using ProcessPoolExecutor Main thread : collect results IN ORDER → route train/val → flush shards Why this is faster: Old code: stream → [normalize] → [tokenize 1000 docs, 1 CPU] → write New code: stream → [normalize] → [tokenize 1000 docs × N cores] → write On 12-core machine: expect 6-10× speedup on tokenization step. Bottleneck shifts to HF streaming bandwidth, not CPU. Notes: - Workers are initialized ONCE with the tokenizer loaded (no repeated disk reads) - Results collected in SUBMISSION ORDER so train/val routing is deterministic - Sliding window of MAX_PENDING futures keeps all cores busy without unbounded memory growth - Ctrl+C safe: flushes remaining buffers before exit """ import os import sys import time import warnings import numpy as np from collections import deque from concurrent.futures import ProcessPoolExecutor from datasets import load_dataset from transformers import PreTrainedTokenizerFast, logging as hf_logging from tqdm import tqdm # Import normalizer from same directory sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from normalizer import normalization hf_logging.set_verbosity_error() warnings.filterwarnings("ignore") # ------------------------------------------------------------------ # # CONSTANTS # ------------------------------------------------------------------ # DATASET_NAME = "HuggingFaceFW/fineweb-edu" DATASET_SUBSET = "CC-MAIN-2014-49" SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) TOKENIZER_DIR = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer") DATA_DIR = os.path.join(SCRIPT_DIR, "data") MIN_QUALITY = 3 SHARD_SIZE = 100_000_000 # tokens per shard (~190 MB at uint16) BATCH_SIZE = 2_000 # docs per tokenization task (↑ from 1000) VAL_RATIO = 100 # every 100th accepted doc → 
SHUFFLE_BUFFER = 10_000
MIN_DOC_LENGTH = 100
DTYPE = np.uint16
MAX_TOKENS = 3_200_000_000

# Parallel workers: leave 2 cores for OS + HF streaming.
# os.cpu_count() returns None when the CPU count cannot be determined, so
# fall back to 1 instead of crashing on `None - 2`.
N_WORKERS = max(1, (os.cpu_count() or 1) - 2)

# How many tokenization futures to keep in-flight at once.
# N_WORKERS × 2 keeps the pipeline full without excess memory.
MAX_PENDING = N_WORKERS * 2


# ------------------------------------------------------------------ #
# WORKER PROCESS — loaded once per process at startup
# ------------------------------------------------------------------ #

# Module-level tokenizer in each worker process
_worker_tokenizer = None


def _worker_init(tokenizer_dir: str):
    """
    Initializer run ONCE in every worker process at startup.

    Loads the tokenizer into the worker's module-level state so that later
    calls to _tokenize_worker_fn reuse it instead of reloading from disk.

    Args:
        tokenizer_dir : directory passed to
                        PreTrainedTokenizerFast.from_pretrained

    Raises:
        ValueError : if the tokenizer vocabulary is too large for DTYPE.
                     Raising here breaks the pool, so the main process
                     fails fast instead of writing corrupted shards.
    """
    global _worker_tokenizer
    import signal
    import warnings
    from transformers import PreTrainedTokenizerFast, logging as hf_log

    # Ctrl+C is handled by the main process only: workers ignore SIGINT so
    # in-flight batches can still be collected and flushed after an interrupt.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    hf_log.set_verbosity_error()
    warnings.filterwarnings("ignore")

    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)

    # Shards are stored as DTYPE; an ID above its maximum would wrap (or
    # raise, depending on the NumPy version) when the shard is saved.
    # len(tokenizer) is the vocabulary size including added tokens.
    id_capacity = int(np.iinfo(DTYPE).max) + 1
    if len(tokenizer) > id_capacity:
        raise ValueError(
            f"Tokenizer vocabulary ({len(tokenizer):,} entries) does not fit "
            f"in {np.dtype(DTYPE).name} shards (max {id_capacity:,} ids)"
        )

    _worker_tokenizer = tokenizer


def _tokenize_worker_fn(texts: list) -> list:
    """
    Tokenizes a batch of pre-normalized texts in a worker process.

    The tokenizer is called with add_special_tokens=True; the original
    author relies on this to append <|endoftext|> to every document
    (NOTE(review): that depends on the tokenizer's post-processor config —
    confirm against the tokenizer in TOKENIZER_DIR).

    Args:
        texts : list of normalized strings (already filtered, normalized)

    Returns:
        list of list[int] — token IDs per document, in input order

    Raises:
        RuntimeError : if called in a process where _worker_init never ran
    """
    if _worker_tokenizer is None:
        raise RuntimeError(
            "_tokenize_worker_fn called before _worker_init loaded the tokenizer"
        )

    encoded = _worker_tokenizer(
        texts,
        add_special_tokens=True,       # let the tokenizer add its special tokens
        truncation=False,              # keep full document
        padding=False,                 # no padding (we pack shards)
        return_attention_mask=False,   # not needed
    )
    return encoded["input_ids"]


# ------------------------------------------------------------------ #
# SHARD HELPERS
# ------------------------------------------------------------------ #

def get_shard_path(split: str, shard_idx: int) -> str:
    """Returns the path of shard `shard_idx` of `split` inside DATA_DIR."""
    return os.path.join(DATA_DIR, f"{split}_{shard_idx:03d}.bin")


def save_shard(tokens: list, split: str, shard_idx: int):
    """
    Writes `tokens` to disk as a flat binary array of DTYPE and logs its size.

    Args:
        tokens    : flat list of token IDs
        split     : split name used in the file name ("train" / "val")
        shard_idx : zero-based shard number used in the file name
    """
    arr = np.array(tokens, dtype=DTYPE)
    path = get_shard_path(split, shard_idx)
    arr.tofile(path)
    size_mb = arr.nbytes / 1024 / 1024
    tqdm.write(f" saved {split}_{shard_idx:03d}.bin | {len(tokens):,} tokens | {size_mb:.1f} MB")


# ------------------------------------------------------------------ #
# ROUTE BATCH RESULTS → train / val buffers
# ------------------------------------------------------------------ #

def route_results(
    all_ids: list,
    doc_count_start: int,
    train_buffer: list,
    val_buffer: list,
    train_tokens: int,
    val_tokens: int,
    total_tokens: int,
    val_ratio=None,
) -> tuple:
    """
    Routes tokenized docs to the train or val buffer by global doc index.

    A document whose index (doc_count_start + position in batch) is a
    multiple of the ratio goes to `val_buffer`; every other document goes
    to `train_buffer`. Both buffers are extended IN PLACE.

    Args:
        all_ids         : list of token-ID lists, one per document
        doc_count_start : global index of the first document in `all_ids`
        train_buffer    : flat token list for the train split (mutated)
        val_buffer      : flat token list for the val split (mutated)
        train_tokens    : running train token count before this batch
        val_tokens      : running val token count before this batch
        total_tokens    : running overall token count before this batch
        val_ratio       : route every `val_ratio`-th doc to val; defaults to
                          the module-level VAL_RATIO when None

    Returns:
        (train_buffer, val_buffer, train_tokens, val_tokens,
         total_tokens, batch_tok_count) with the counters updated.

    Raises:
        ValueError : if the effective ratio is not a positive integer
    """
    ratio = VAL_RATIO if val_ratio is None else val_ratio
    if ratio <= 0:
        raise ValueError(f"val_ratio must be a positive integer, got {ratio!r}")

    batch_tok_count = 0
    for doc_num, ids in enumerate(all_ids, start=doc_count_start):
        n_ids = len(ids)
        if doc_num % ratio == 0:  # every ratio-th doc → val
            val_buffer.extend(ids)
            val_tokens += n_ids
        else:
            train_buffer.extend(ids)
            train_tokens += n_ids
        batch_tok_count += n_ids

    total_tokens += batch_tok_count
    return train_buffer, val_buffer, train_tokens, val_tokens, total_tokens, batch_tok_count


# ------------------------------------------------------------------ #
# MAIN PARALLEL TOKENIZATION PIPELINE
# ------------------------------------------------------------------ #

def tokenize_dataset():
    """
    Streams the dataset, tokenizes it in parallel and writes token shards.

    Pipeline (all orchestration happens in the calling process):
      1. Stream DATASET_NAME / DATASET_SUBSET with a shuffle buffer.
      2. Drop docs with int_score < MIN_QUALITY or fewer than MIN_DOC_LENGTH
         characters (checked both before and after normalization).
      3. Submit batches of BATCH_SIZE texts to a ProcessPoolExecutor.
      4. Collect results oldest-first so doc indices (and therefore the
         train/val routing) are deterministic, keeping fewer than
         MAX_PENDING batches in flight.
      5. Write a shard every SHARD_SIZE tokens per split; write the partial
         remainder at the end.

    The token cap is only checked after a batch is submitted and the queue
    has been drained below MAX_PENDING, and the batches still in flight are
    collected afterwards, so the final count can exceed MAX_TOKENS by up to
    roughly MAX_PENDING batches.

    Ctrl+C: workers ignore SIGINT (see _worker_init). The main process stops
    streaming, collects the batches already submitted, flushes the buffers
    and prints the summary. The batch being handled at the instant of the
    interrupt may be dropped; a second Ctrl+C aborts immediately.
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    print(f"Loading tokenizer from: {TOKENIZER_DIR}")
    print(f" workers : {N_WORKERS} of {os.cpu_count()} CPUs")

    print(f"\nLoading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
    ds = load_dataset(
        DATASET_NAME,
        name=DATASET_SUBSET,
        split="train",
        streaming=True,
    ).shuffle(buffer_size=SHUFFLE_BUFFER, seed=42)

    # ---- State ------------------------------------------------------ #
    train_buffer = []
    val_buffer = []
    train_shard = 0
    val_shard = 0

    total_docs = 0
    skipped_docs = 0
    total_tokens = 0
    train_tokens = 0
    val_tokens = 0

    batch_texts = []  # accumulating next batch to submit

    # pending: deque of (future, doc_count_start)
    # We always pop from the LEFT (oldest submission) to preserve order
    pending = deque()

    cap_reached = False
    interrupted = False

    # ---- Progress bars ----------------------------------------------- #
    token_bar = tqdm(
        total=MAX_TOKENS,
        desc="tokens",
        unit="tok",
        unit_scale=True,
        unit_divisor=1000,
        colour="green",
        position=0,
    )
    doc_bar = tqdm(
        desc="docs ",
        unit="doc",
        unit_scale=True,
        colour="blue",
        position=1,
    )

    t_start = time.time()

    # -------------------------------------------------------------- #
    # DRAIN HELPER — collect the oldest pending future and process it
    # -------------------------------------------------------------- #
    def drain_one():
        """Collects the oldest batch, routes it and flushes full shards.

        Returns False when nothing was pending, True otherwise.
        """
        nonlocal train_buffer, val_buffer, train_shard, val_shard
        nonlocal total_tokens, train_tokens, val_tokens

        if not pending:
            return False

        future, doc_start = pending.popleft()
        all_ids = future.result()  # blocks until this task is done

        (train_buffer, val_buffer,
         train_tokens, val_tokens,
         total_tokens, batch_tok) = route_results(
            all_ids, doc_start,
            train_buffer, val_buffer,
            train_tokens, val_tokens, total_tokens,
        )

        token_bar.update(batch_tok)
        token_bar.set_postfix({
            "train": f"{train_tokens/1e9:.2f}B",
            "val": f"{val_tokens/1e6:.0f}M",
            "shards": train_shard,
        })

        # Flush train shards
        while len(train_buffer) >= SHARD_SIZE:
            save_shard(train_buffer[:SHARD_SIZE], "train", train_shard)
            train_buffer = train_buffer[SHARD_SIZE:]
            train_shard += 1

        # Flush val shards
        while len(val_buffer) >= SHARD_SIZE:
            save_shard(val_buffer[:SHARD_SIZE], "val", val_shard)
            val_buffer = val_buffer[SHARD_SIZE:]
            val_shard += 1

        return True

    def note_skipped():
        """Counts one filtered-out document and shows it on the doc bar."""
        nonlocal skipped_docs
        skipped_docs += 1
        doc_bar.set_postfix({"skipped": skipped_docs})

    # -------------------------------------------------------------- #
    # MAIN LOOP with ProcessPoolExecutor
    # -------------------------------------------------------------- #
    print(f"\nStarting tokenization...")
    print(f" token target : {MAX_TOKENS:,}")
    print(f" shard size : {SHARD_SIZE:,} tokens")
    print(f" batch size : {BATCH_SIZE} docs")
    print(f" val ratio : every {VAL_RATIO}th doc")
    print(f" quality : int_score >= {MIN_QUALITY}\n")

    with ProcessPoolExecutor(
        max_workers=N_WORKERS,
        initializer=_worker_init,
        initargs=(TOKENIZER_DIR,),
    ) as executor:

        try:
            for doc in ds:

                # ---- Quality filter -------------------------------- #
                if doc["int_score"] < MIN_QUALITY:
                    note_skipped()
                    continue

                # ---- Length + normalize ---------------------------- #
                text = doc["text"]
                if len(text) < MIN_DOC_LENGTH:
                    note_skipped()
                    continue

                text = normalization(text)
                if len(text) < MIN_DOC_LENGTH:
                    note_skipped()
                    continue

                batch_texts.append(text)
                total_docs += 1
                doc_bar.update(1)

                if len(batch_texts) < BATCH_SIZE:
                    continue

                # ---- Submit batch when full ------------------------ #
                # Detach the batch BEFORE submitting: if Ctrl+C lands in
                # between, the batch is dropped rather than submitted twice
                # by the partial-batch code below.
                full_batch, batch_texts = batch_texts, []
                doc_start = total_docs - BATCH_SIZE
                future = executor.submit(_tokenize_worker_fn, full_batch)
                pending.append((future, doc_start))

                # ---- Backpressure: drain oldest if queue full ------- #
                # This prevents unbounded memory accumulation
                # while keeping all N_WORKERS busy
                while len(pending) >= MAX_PENDING:
                    drain_one()

                # ---- Check token cap -------------------------------- #
                if total_tokens >= MAX_TOKENS:
                    tqdm.write(f"\nToken cap reached: {total_tokens:,} tokens from {total_docs:,} docs")
                    cap_reached = True
                    break

        except KeyboardInterrupt:
            # Stop streaming but keep everything tokenized so far; workers
            # ignore SIGINT, so submitted batches still complete.
            interrupted = True
            tqdm.write("\nInterrupted: collecting in-flight batches and flushing buffers...")

        # ---- Submit any remaining partial batch -------------------- #
        if batch_texts and not cap_reached:
            doc_start = total_docs - len(batch_texts)
            future = executor.submit(_tokenize_worker_fn, batch_texts)
            pending.append((future, doc_start))

        # ---- Drain all remaining pending futures ------------------- #
        while pending:
            drain_one()

    # ---- Close progress bars --------------------------------------- #
    token_bar.close()
    doc_bar.close()

    # ---- Save remaining partial shards ----------------------------- #
    if train_buffer:
        save_shard(train_buffer, "train", train_shard)
        train_shard += 1

    if val_buffer:
        save_shard(val_buffer, "val", val_shard)
        val_shard += 1

    elapsed_min = (time.time() - t_start) / 60

    # ---- Final summary --------------------------------------------- #
    print(f"\n{'='*60}")
    if interrupted:
        print(" TOKENIZATION INTERRUPTED (partial output flushed)")
    else:
        print(f" TOKENIZATION COMPLETE")
    print(f"{'='*60}")
    print(f" total docs : {total_docs:,}")
    print(f" skipped docs : {skipped_docs:,}")
    print(f" total tokens : {total_tokens:,}")
    print(f" train tokens : {train_tokens:,}")
    print(f" val tokens : {val_tokens:,}")
    print(f" train shards : {train_shard}")
    print(f" val shards : {val_shard}")
    print(f" elapsed : {elapsed_min:.1f} min")
    print(f" data dir : {os.path.abspath(DATA_DIR)}")


# ------------------------------------------------------------------ #
# LOAD SHARDS DURING TRAINING (unchanged)
# ------------------------------------------------------------------ #

def load_shard(split: str, shard_idx: int) -> np.ndarray:
    """
    Opens one token shard as a read-only, memory-mapped array of DTYPE.

    The file is mapped rather than read, so slicing the result only pulls
    in the part of the shard that is actually accessed.

    Usage during training:
        shard = load_shard("train", 0)
        chunk = shard[i : i + 1024]

    Args:
        split     : split name used in the shard file name ("train" / "val")
        shard_idx : zero-based shard number

    Returns:
        np.memmap over the shard file located by get_shard_path
    """
    return np.memmap(get_shard_path(split, shard_idx), dtype=DTYPE, mode="r")


# ------------------------------------------------------------------ #
# ENTRY POINT
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    # The guard keeps worker processes that re-import this module (spawn
    # start method, e.g. on Windows) from re-running the pipeline.
    tokenize_dataset()