""" Dataset Tokenizer ================= Tokenizes datasets and saves as binary files for mmap loading. Usage: # HuggingFace dataset python -m freqformer.tokenize_data --dataset wikitext --subset wikitext-2-raw-v1 --out data/wiki2 # Custom text files python -m freqformer.tokenize_data --files train.txt val.txt --out data/custom # Append more data to existing binary files (does not overwrite) python -m freqformer.tokenize_data --dataset openwebtext --out data/mixed --append python -m freqformer.tokenize_data --dataset allenai/dolma3_pool --out data/mixed --append --streaming """ import argparse import os import re import struct import numpy as np from pathlib import Path from transformers import AutoTokenizer HEADER_MAGIC = b"FREQTOK1" # 8 bytes magic HEADER_SIZE = 32 # magic(8) + version(4) + num_tokens(8) + seq_len(4) + vocab_size(4) + reserved(4) def write_binary(tokens: list[int], path: str, vocab_size: int): """Write tokenized data as binary file with header.""" os.makedirs(os.path.dirname(path) or ".", exist_ok=True) arr = np.array(tokens, dtype=np.uint16 if vocab_size < 65536 else np.uint32) with open(path, "wb") as f: # Header f.write(HEADER_MAGIC) f.write(struct.pack("= 65535: raise ValueError( f"Cannot append: token ID {max_id} exceeds uint16 range but existing file uses uint16. " f"Recreate the file with a larger vocab_size tokenizer." ) arr = np.array(tokens, dtype=dtype) new_total = old_num_tokens + len(tokens) # Append data and update header with open(path, "r+b") as f: # Update num_tokens in header (offset 12 = magic(8) + version(4)) f.seek(12) f.write(struct.pack(" int: with open(path, "rb") as f: magic = f.read(8) if magic != HEADER_MAGIC: raise ValueError(f"{path} is not a valid FREQTOK1 file") _version = struct.unpack(" 0: _init_binary(val_path, vocab_size) else: if not os.path.exists(train_path): _init_binary(train_path, vocab_size) if val_ratio > 0 and not os.path.exists(val_path): _init_binary(val_path, vocab_size) train_written = _read_existing_count(train_path) if append else 0 val_written = _read_existing_count(val_path) if (append and val_ratio > 0) else 0 rng = np.random.default_rng(seed) train_f = open(train_path, "r+b") train_f.seek(0, 2) val_f = None if val_ratio > 0: val_f = open(val_path, "r+b") val_f.seek(0, 2) total_tokens = 0 row_count = 0 log_every = 100000 try: for row in ds_tok: ids = row["input_ids"] if not ids: continue arr = np.array(ids, dtype=dtype) if val_ratio > 0 and rng.random() < val_ratio: val_f.write(arr.tobytes()) val_written += len(arr) else: train_f.write(arr.tobytes()) train_written += len(arr) total_tokens = train_written + val_written row_count += 1 if row_count % log_every == 0: print(f" {row_count:>10,} docs | {total_tokens:>12,} tokens " f"(train: {train_written:,}, val: {val_written:,})") if max_tokens and total_tokens >= max_tokens: break train_f.seek(12) train_f.write(struct.pack(" 0: val_f.seek(12) val_f.write(struct.pack(" 0 and not os.path.exists(val_path): _init_binary(val_path, vocab_size) train_written = _read_existing_count(train_path) val_written = _read_existing_count(val_path) if val_ratio > 0 else 0 else: _init_binary(train_path, vocab_size) train_written = 0 if val_ratio > 0: _init_binary(val_path, vocab_size) val_written = 0 count = 0 batch_texts = [] batch_size = 10000 log_every = 100000 rng = np.random.default_rng(seed) train_f = open(train_path, "r+b") train_f.seek(0, 2) val_f = None if val_ratio > 0: val_f = open(val_path, "r+b") val_f.seek(0, 2) def _write_ids(ids_list: list[list[int]]): nonlocal train_written, val_written for ids in ids_list: arr = np.array(ids, dtype=dtype) if val_ratio > 0 and rng.random() < val_ratio: val_f.write(arr.tobytes()) val_written += len(arr) else: train_f.write(arr.tobytes()) train_written += len(arr) print(f"\nTokenizing (streaming, batch_size={batch_size:,}, direct writes)...") exclude_files: set[str] = set() skip_errors_enabled = skip_errors and data_files is not None def _extract_bad_file(err: Exception) -> str | None: msg = str(err) match = re.search(r"file '.*::([^']+)'", msg) if match: return match.group(1) match = re.search(r"::([^'\s]+)" , msg) if match: return match.group(1) return None try: while True: if skip_errors_enabled and exclude_files: ds = _load_streaming_dataset(exclude_files) if skip_errors_enabled and count > 0: ds = ds.skip(count) try: for example in ds: if max_samples and count >= max_samples: break if max_tokens and (train_written + val_written) >= max_tokens: break text = example[text_col] if not text or not text.strip(): continue batch_texts.append(text) count += 1 if len(batch_texts) >= batch_size: encoded = tokenizer( batch_texts, add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False, ) _write_ids(encoded["input_ids"]) batch_texts = [] if count % log_every == 0: total_tokens = train_written + val_written print( f" {count:>10,} samples | {total_tokens:>12,} tokens " f"(train: {train_written:,}, val: {val_written:,})" ) break except Exception as err: if not skip_errors_enabled: raise bad_file = _extract_bad_file(err) if not bad_file or bad_file in exclude_files: raise print(f" Warning: failed to read {bad_file}; skipping and resuming...") exclude_files.add(bad_file) continue # Flush remaining if batch_texts: encoded = tokenizer( batch_texts, add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False, ) _write_ids(encoded["input_ids"]) train_f.seek(12) train_f.write(struct.pack(" 0: val_f.seek(12) val_f.write(struct.pack("