# sllm/tokenizer/tokenize_dataset.py
# Repository page header: geeteshcodes · "Initial commit" · 7f974df (verified)
"""
tokenize_dataset.py — Parallel tokenization pipeline
Architecture:
Main thread : stream HF dataset → filter → normalize → batch texts
Worker pool : N_WORKERS processes, each with own loaded tokenizer,
tokenize batches concurrently using ProcessPoolExecutor
Main thread : collect results IN ORDER → route train/val → flush shards
Why this is faster:
Old code: stream → [normalize] → [tokenize 1000 docs, 1 CPU] → write
New code: stream → [normalize] → [tokenize 1000 docs × N cores] → write
On 12-core machine: expect 6-10× speedup on tokenization step.
Bottleneck shifts to HF streaming bandwidth, not CPU.
Notes:
- Workers are initialized ONCE with the tokenizer loaded (no repeated disk reads)
- Results collected in SUBMISSION ORDER so train/val routing is deterministic
- Sliding window of MAX_PENDING futures keeps all cores busy without
unbounded memory growth
- Ctrl+C safe: flushes remaining buffers before exit
"""
import os
import sys
import time
import warnings
import numpy as np
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast, logging as hf_logging
from tqdm import tqdm
# Import normalizer from same directory
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from normalizer import normalization
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore")
# ------------------------------------------------------------------ #
# CONSTANTS
# ------------------------------------------------------------------ #
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_SUBSET = "CC-MAIN-2014-49"
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
TOKENIZER_DIR = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")
DATA_DIR = os.path.join(SCRIPT_DIR, "data")

MIN_QUALITY = 3  # keep docs whose int_score is at least this
SHARD_SIZE = 100_000_000  # tokens per shard (~190 MB at uint16)
BATCH_SIZE = 2_000  # docs per tokenization task (↑ from 1000)
VAL_RATIO = 100  # every 100th accepted doc → val
SHUFFLE_BUFFER = 10_000  # buffer_size passed to the streaming shuffle
MIN_DOC_LENGTH = 100  # minimum characters, checked before and after normalization
DTYPE = np.uint16  # on-disk token dtype; can only represent IDs 0..65535
MAX_TOKENS = 3_200_000_000  # token cap checked by the main loop

# Parallel workers: leave 2 cores for OS + HF streaming.
# os.cpu_count() returns None when the CPU count cannot be determined, so
# fall back to 1 instead of failing with a TypeError at import time.
N_WORKERS = max(1, (os.cpu_count() or 1) - 2)

# How many tokenization futures to keep in-flight at once
# = N_WORKERS × 2 keeps the pipeline full without excess memory
MAX_PENDING = N_WORKERS * 2

# ------------------------------------------------------------------ #
# WORKER PROCESS — loaded once per process at startup
# ------------------------------------------------------------------ #
# Module-level tokenizer in each worker process.
# Stays None until _worker_init() has run in that process.
_worker_tokenizer = None
def _worker_init(tokenizer_dir: str):
    """
    Pool initializer: prepare one worker process for tokenization.

    Runs when the worker process starts (it is passed as `initializer=` to
    ProcessPoolExecutor). It silences Python warnings and transformers
    logging in that process, then loads the tokenizer from `tokenizer_dir`
    into the module-level `_worker_tokenizer`, which _tokenize_worker_fn
    reads on every task.

    Args:
        tokenizer_dir : directory passed to PreTrainedTokenizerFast.from_pretrained
    """
    global _worker_tokenizer

    # Imported inside the function, as in the original design, so the
    # initializer carries its own dependencies.
    import warnings
    from transformers import PreTrainedTokenizerFast, logging as worker_hf_logging

    # Keep worker output quiet so it does not disturb the progress bars.
    warnings.filterwarnings("ignore")
    worker_hf_logging.set_verbosity_error()

    _worker_tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
def _tokenize_worker_fn(texts: list) -> list:
    """
    Tokenize one batch of texts inside a worker process.

    Uses the tokenizer stored in the module-level `_worker_tokenizer`
    (set by _worker_init), so it must only run in an initialized worker.

    Args:
        texts : list of strings to tokenize (the caller passes texts that
                are already filtered and normalized)
    Returns:
        The "input_ids" entry of the tokenizer output: one list of token
        IDs per input text, in the same order.

    Note: add_special_tokens=True lets the tokenizer add its special tokens.
    The original author's note says this appends <|endoftext|> to each doc;
    that depends on the tokenizer's own configuration — TODO confirm.
    """
    tokenizer_options = {
        "add_special_tokens": True,       # let the tokenizer add special tokens
        "truncation": False,              # keep the full document
        "padding": False,                 # no padding (shards are packed)
        "return_attention_mask": False,   # not needed for packed shards
    }
    return _worker_tokenizer(texts, **tokenizer_options)["input_ids"]
# ------------------------------------------------------------------ #
# SHARD HELPERS
# ------------------------------------------------------------------ #
def get_shard_path(split: str, shard_idx: int) -> str:
    """
    Build the file path of one shard inside DATA_DIR.

    Args:
        split     : split name used as the file-name prefix
        shard_idx : shard number, zero-padded to at least 3 digits
    Returns:
        DATA_DIR joined with "<split>_<shard_idx>.bin"
    """
    shard_filename = f"{split}_{shard_idx:03d}.bin"
    return os.path.join(DATA_DIR, shard_filename)
def save_shard(tokens: list, split: str, shard_idx: int):
    """
    Write one shard of token IDs to disk as a flat binary file.

    The IDs are converted to a 1-D array of DTYPE and dumped with
    ndarray.tofile (raw values, no header) at get_shard_path(split, shard_idx).

    The data is first written to "<shard path>.tmp" and then moved into
    place with os.replace, so an interrupted or crashed write never leaves a
    truncated file under the final shard name.

    Args:
        tokens    : flat list of token IDs; every ID must be representable
                    in DTYPE
        split     : split name used in the file name (e.g. "train" or "val")
        shard_idx : shard number used in the file name
    """
    arr = np.array(tokens, dtype=DTYPE)
    path = get_shard_path(split, shard_idx)

    # Write to a sibling temp file first; os.replace swaps it in atomically
    # (same directory, hence same filesystem).
    tmp_path = path + ".tmp"
    arr.tofile(tmp_path)
    os.replace(tmp_path, path)

    size_mb = arr.nbytes / 1024 / 1024
    tqdm.write(f" saved {split}_{shard_idx:03d}.bin | {len(tokens):,} tokens | {size_mb:.1f} MB")
# ------------------------------------------------------------------ #
# ROUTE BATCH RESULTS → train / val buffers
# ------------------------------------------------------------------ #
def route_results(
    all_ids        : list,
    doc_count_start: int,
    train_buffer   : list,
    val_buffer     : list,
    train_tokens   : int,
    val_tokens     : int,
    total_tokens   : int,
) -> tuple:
    """
    Append each tokenized doc of a batch to the train or the val buffer.

    Docs are numbered consecutively from `doc_count_start`; a doc whose
    number is a multiple of VAL_RATIO goes to `val_buffer`, every other doc
    goes to `train_buffer`. Both buffers are extended in place.

    Args:
        all_ids         : list of token-ID lists, one per document
        doc_count_start : doc number of the first entry in `all_ids`
        train_buffer    : flat list of train token IDs (mutated)
        val_buffer      : flat list of val token IDs (mutated)
        train_tokens    : running train token count before this batch
        val_tokens      : running val token count before this batch
        total_tokens    : running overall token count before this batch
    Returns:
        (train_buffer, val_buffer, train_tokens, val_tokens, total_tokens,
        batch_tok_count) with the counters updated for this batch;
        batch_tok_count is the number of tokens in this batch alone.
    """
    batch_tok_count = 0
    for doc_num, ids in enumerate(all_ids, start=doc_count_start):
        n_ids = len(ids)
        goes_to_val = doc_num % VAL_RATIO == 0  # every VAL_RATIO-th doc → val
        if goes_to_val:
            val_buffer.extend(ids)
            val_tokens += n_ids
        else:
            train_buffer.extend(ids)
            train_tokens += n_ids
        batch_tok_count += n_ids

    return (
        train_buffer,
        val_buffer,
        train_tokens,
        val_tokens,
        total_tokens + batch_tok_count,
        batch_tok_count,
    )
# ------------------------------------------------------------------ #
# MAIN PARALLEL TOKENIZATION PIPELINE
# ------------------------------------------------------------------ #
def tokenize_dataset():
    """
    Stream the dataset, tokenize it in parallel, and write train/val shards.

    Pipeline:
      1. The main thread streams DATASET_NAME / DATASET_SUBSET (shuffled),
         drops docs with int_score < MIN_QUALITY or fewer than
         MIN_DOC_LENGTH characters (checked before and after normalization),
         and groups the survivors into batches of BATCH_SIZE.
      2. Each batch is submitted to a ProcessPoolExecutor with N_WORKERS
         processes; once MAX_PENDING batches are in flight the oldest one
         is collected before streaming continues.
      3. Results are consumed oldest-first, so the doc number that decides
         train vs. val routing does not depend on worker timing.
      4. Buffers are written out in SHARD_SIZE-token shards; whatever is
         left at the end is written as a final, smaller shard.

    The token cap is evaluated after a batch has been submitted, and batches
    already in flight are still collected, so the final token count can
    exceed MAX_TOKENS by up to roughly MAX_PENDING batches.

    Ctrl+C (KeyboardInterrupt) while streaming or tokenizing stops the
    pipeline: batches that have not started are cancelled, results not yet
    collected are discarded, and the tokens already routed into the buffers
    are still written as final shards. This is best-effort — an interrupt
    that lands in the middle of a buffer update or shard flush can leave the
    counters or the last shard slightly off.
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    print(f"Loading tokenizer from: {TOKENIZER_DIR}")
    print(f" workers : {N_WORKERS} of {os.cpu_count()} CPUs")

    print(f"\nLoading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
    ds = load_dataset(
        DATASET_NAME,
        name=DATASET_SUBSET,
        split="train",
        streaming=True,
    ).shuffle(buffer_size=SHUFFLE_BUFFER, seed=42)

    # ---- State ------------------------------------------------------ #
    train_buffer = []
    val_buffer = []
    train_shard = 0
    val_shard = 0
    total_docs = 0
    skipped_docs = 0
    total_tokens = 0
    train_tokens = 0
    val_tokens = 0
    batch_texts = []  # accumulating next batch to submit

    # pending: deque of (future, doc_count_start)
    # We always pop from the LEFT (oldest submission) to preserve order
    pending = deque()
    cap_reached = False
    interrupted = False

    # ---- Progress bars ----------------------------------------------- #
    token_bar = tqdm(
        total=MAX_TOKENS,
        desc="tokens",
        unit="tok",
        unit_scale=True,
        unit_divisor=1000,
        colour="green",
        position=0,
    )
    doc_bar = tqdm(
        desc="docs ",
        unit="doc",
        unit_scale=True,
        colour="blue",
        position=1,
    )
    t_start = time.time()

    # ------------------------------------------------------------------ #
    # SKIP HELPER — shared bookkeeping for every rejected document
    # ------------------------------------------------------------------ #
    def skip_doc():
        nonlocal skipped_docs
        skipped_docs += 1
        doc_bar.set_postfix({"skipped": skipped_docs})

    # ------------------------------------------------------------------ #
    # DRAIN HELPER — collect the oldest pending future and process it
    # ------------------------------------------------------------------ #
    def drain_one():
        """Collect the oldest in-flight batch, route it, flush full shards.

        Returns False when nothing is pending, True otherwise.
        """
        nonlocal train_buffer, val_buffer, train_shard, val_shard
        nonlocal total_tokens, train_tokens, val_tokens

        if not pending:
            return False

        future, doc_start = pending.popleft()
        all_ids = future.result()  # blocks until this task is done

        (train_buffer, val_buffer,
         train_tokens, val_tokens,
         total_tokens, batch_tok) = route_results(
            all_ids, doc_start,
            train_buffer, val_buffer,
            train_tokens, val_tokens, total_tokens,
        )

        token_bar.update(batch_tok)
        token_bar.set_postfix({
            "train": f"{train_tokens/1e9:.2f}B",
            "val" : f"{val_tokens/1e6:.0f}M",
            "shards": train_shard,
        })

        # Flush train shards
        while len(train_buffer) >= SHARD_SIZE:
            save_shard(train_buffer[:SHARD_SIZE], "train", train_shard)
            train_buffer = train_buffer[SHARD_SIZE:]
            train_shard += 1

        # Flush val shards
        while len(val_buffer) >= SHARD_SIZE:
            save_shard(val_buffer[:SHARD_SIZE], "val", val_shard)
            val_buffer = val_buffer[SHARD_SIZE:]
            val_shard += 1

        return True

    # ------------------------------------------------------------------ #
    # MAIN LOOP with ProcessPoolExecutor
    # ------------------------------------------------------------------ #
    print("\nStarting tokenization...")
    print(f" token target : {MAX_TOKENS:,}")
    print(f" shard size : {SHARD_SIZE:,} tokens")
    print(f" batch size : {BATCH_SIZE} docs")
    print(f" val ratio : every {VAL_RATIO}th doc")
    print(f" quality : int_score >= {MIN_QUALITY}\n")

    with ProcessPoolExecutor(
        max_workers=N_WORKERS,
        initializer=_worker_init,
        initargs=(TOKENIZER_DIR,),
    ) as executor:
        try:
            for doc in ds:
                # ---- Quality filter -------------------------------- #
                if doc["int_score"] < MIN_QUALITY:
                    skip_doc()
                    continue

                # ---- Length + normalize ---------------------------- #
                text = doc["text"]
                if len(text) < MIN_DOC_LENGTH:
                    skip_doc()
                    continue
                text = normalization(text)
                if len(text) < MIN_DOC_LENGTH:
                    skip_doc()
                    continue

                batch_texts.append(text)
                total_docs += 1
                doc_bar.update(1)

                # ---- Submit batch when full ------------------------ #
                if len(batch_texts) == BATCH_SIZE:
                    # Record which doc index this batch starts at
                    doc_start = total_docs - BATCH_SIZE
                    future = executor.submit(_tokenize_worker_fn, batch_texts)
                    pending.append((future, doc_start))
                    batch_texts = []

                    # ---- Backpressure: drain oldest if queue full --- #
                    # This prevents unbounded memory accumulation
                    # while keeping all N_WORKERS busy
                    while len(pending) >= MAX_PENDING:
                        drain_one()

                    # ---- Check token cap ---------------------------- #
                    # total_tokens only changes inside drain_one(), so
                    # checking here (right after draining) is sufficient.
                    if total_tokens >= MAX_TOKENS:
                        tqdm.write(f"\nToken cap reached: {total_tokens:,} tokens from {total_docs:,} docs")
                        cap_reached = True
                        break

            # ---- Submit any remaining partial batch ---------------- #
            if batch_texts and not cap_reached:
                doc_start = total_docs - len(batch_texts)
                future = executor.submit(_tokenize_worker_fn, batch_texts)
                pending.append((future, doc_start))

            # ---- Drain all remaining pending futures --------------- #
            while pending:
                drain_one()

        except KeyboardInterrupt:
            # Stop cleanly: do not wait for uncollected results (the workers
            # may have been interrupted as well), just cancel what has not
            # started and fall through to the final flush below.
            interrupted = True
            tqdm.write("\nInterrupted (Ctrl+C): cancelling queued batches and flushing buffers...")
            for queued_future, _ in pending:
                queued_future.cancel()
            pending.clear()

    # ---- Close progress bars --------------------------------------- #
    token_bar.close()
    doc_bar.close()

    # ---- Save remaining partial shards ----------------------------- #
    if train_buffer:
        save_shard(train_buffer, "train", train_shard)
        train_shard += 1
    if val_buffer:
        save_shard(val_buffer, "val", val_shard)
        val_shard += 1

    # ---- Final summary --------------------------------------------- #
    elapsed_min = (time.time() - t_start) / 60
    print(f"\n{'='*60}")
    if interrupted:
        print(" TOKENIZATION INTERRUPTED (buffers flushed)")
    else:
        print(" TOKENIZATION COMPLETE")
    print(f"{'='*60}")
    print(f" total docs : {total_docs:,}")
    print(f" skipped docs : {skipped_docs:,}")
    print(f" total tokens : {total_tokens:,}")
    print(f" train tokens : {train_tokens:,}")
    print(f" val tokens : {val_tokens:,}")
    print(f" train shards : {train_shard}")
    print(f" val shards : {val_shard}")
    print(f" elapsed : {elapsed_min:.1f} min")
    print(f" data dir : {os.path.abspath(DATA_DIR)}")
# ------------------------------------------------------------------ #
# LOAD SHARDS DURING TRAINING (unchanged)
# ------------------------------------------------------------------ #
def load_shard(split: str, shard_idx: int) -> np.ndarray:
    """
    Open a saved shard as a read-only, memory-mapped array of DTYPE.

    The file is mapped rather than read, so slices are fetched from disk
    on access instead of loading the whole shard into RAM.

    Example (during training):
        shard = load_shard("train", 0)
        chunk = shard[i : i + 1024]

    Args:
        split     : split name of the shard
        shard_idx : shard number
    Returns:
        np.memmap over the file at get_shard_path(split, shard_idx)
    """
    shard_file = get_shard_path(split, shard_idx)
    mapped = np.memmap(shard_file, dtype=DTYPE, mode="r")
    return mapped
# ------------------------------------------------------------------ #
# ENTRY POINT
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Run the pipeline only when this file is executed as a script.
    # The guard is required for multiprocessing with the spawn start method
    # (the default on Windows): spawned workers re-import this module, and
    # without the guard each of them would start tokenize_dataset() again.
    tokenize_dataset()