# FreqFormer-32M / freqformer/tokenize_data.py
# Uploaded via huggingface_hub by cturan (commit 01fcd60, verified).
"""
Dataset Tokenizer
=================
Tokenizes datasets and saves as binary files for mmap loading.
Usage:
# HuggingFace dataset
python -m freqformer.tokenize_data --dataset wikitext --subset wikitext-2-raw-v1 --out data/wiki2
# Custom text files
python -m freqformer.tokenize_data --files train.txt val.txt --out data/custom
# Append more data to existing binary files (does not overwrite)
python -m freqformer.tokenize_data --dataset openwebtext --out data/mixed --append
python -m freqformer.tokenize_data --dataset allenai/dolma3_pool --out data/mixed --append --streaming
"""
import argparse
import os
import re
import struct
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer
HEADER_MAGIC = b"FREQTOK1" # 8 bytes magic
HEADER_SIZE = 32 # magic(8) + version(4) + num_tokens(8) + seq_len(4) + vocab_size(4) + reserved(4)
def write_binary(tokens: list[int], path: str, vocab_size: int):
"""Write tokenized data as binary file with header."""
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
arr = np.array(tokens, dtype=np.uint16 if vocab_size < 65536 else np.uint32)
with open(path, "wb") as f:
# Header
f.write(HEADER_MAGIC)
f.write(struct.pack("<I", 1)) # version
f.write(struct.pack("<Q", len(tokens))) # num_tokens
f.write(struct.pack("<I", 0)) # seq_len (0 = not chunked)
f.write(struct.pack("<I", vocab_size)) # vocab_size
f.write(struct.pack("<I", 0)) # reserved
# Data
f.write(arr.tobytes())
size_mb = os.path.getsize(path) / 1024 / 1024
print(f" Saved {path}: {len(tokens):,} tokens ({size_mb:.1f} MB)")
def append_binary(tokens: list[int], path: str, vocab_size: int):
"""Append tokens to an existing binary file, or create if it doesn't exist."""
if not os.path.exists(path):
write_binary(tokens, path, vocab_size)
return
# Read existing header
with open(path, "rb") as f:
magic = f.read(8)
if magic != HEADER_MAGIC:
raise ValueError(f"Cannot append: {path} is not a valid FREQTOK1 file")
_version = struct.unpack("<I", f.read(4))[0]
old_num_tokens = struct.unpack("<Q", f.read(8))[0]
_seq_len = struct.unpack("<I", f.read(4))[0]
old_vocab_size = struct.unpack("<I", f.read(4))[0]
if old_vocab_size != vocab_size:
raise ValueError(
f"Cannot append: vocab_size mismatch ({old_vocab_size} in file vs {vocab_size} new). "
f"Use the same tokenizer."
)
dtype = np.uint16 if old_vocab_size < 65536 else np.uint32
max_id = max(tokens) if tokens else 0
if dtype == np.uint16 and max_id >= 65535:
raise ValueError(
f"Cannot append: token ID {max_id} exceeds uint16 range but existing file uses uint16. "
f"Recreate the file with a larger vocab_size tokenizer."
)
arr = np.array(tokens, dtype=dtype)
new_total = old_num_tokens + len(tokens)
# Append data and update header
with open(path, "r+b") as f:
# Update num_tokens in header (offset 12 = magic(8) + version(4))
f.seek(12)
f.write(struct.pack("<Q", new_total))
# Append data at end
f.seek(0, 2) # seek to end
f.write(arr.tobytes())
size_mb = os.path.getsize(path) / 1024 / 1024
print(f" Appended to {path}: +{len(tokens):,} tokens (total: {new_total:,}, {size_mb:.1f} MB)")
def _init_binary(path: str, vocab_size: int):
"""Create a fresh binary file with header (zero tokens). Returns dtype."""
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
dtype = np.uint16 if vocab_size < 65536 else np.uint32
with open(path, "wb") as f:
f.write(HEADER_MAGIC)
f.write(struct.pack("<I", 1)) # version
f.write(struct.pack("<Q", 0)) # num_tokens (updated later)
f.write(struct.pack("<I", 0)) # seq_len
f.write(struct.pack("<I", vocab_size)) # vocab_size
f.write(struct.pack("<I", 0)) # reserved
return dtype
def tokenize_hf_dataset(
    dataset_name: str, subset: str, tokenizer_name: str, out_dir: str,
    max_samples: int = 0, split: str = "train", val_ratio: float = 0.02,
    streaming: bool = False, append: bool = False, num_proc: int = 0,
    max_tokens: int = 0, seed: int = 1234, skip_errors: bool = False,
):
    """Tokenize a HuggingFace dataset.

    Non-streaming: uses datasets.map() with num_proc for parallel tokenization.
    Streaming: batched tokenization with periodic disk flushes (low RAM).

    Args:
        dataset_name: HF hub dataset name; in streaming mode it may also be a
            local directory of JSON/JSONL files.
        subset: Optional dataset config name (passed as ``name=``).
        tokenizer_name: Name/path for ``AutoTokenizer.from_pretrained``.
        out_dir: Directory receiving ``train.bin`` (and ``val.bin`` when
            ``val_ratio > 0``).
        max_samples: Stop after this many documents (0 = no limit).
        split: Dataset split to load.
        val_ratio: Per-document probability of routing a doc to ``val.bin``.
        streaming: Use datasets streaming mode (constant RAM).
        append: Extend existing .bin files instead of recreating them.
        num_proc: Worker processes for non-streaming map (<=0 = auto, capped at 16).
        max_tokens: Stop once this many total tokens are written (0 = no limit).
        seed: RNG seed for the train/val document split.
        skip_errors: Streaming over local JSON files only: skip files that fail
            to parse and resume the stream.

    Raises:
        ValueError: when appending to a file with a different vocab_size, or
            when all local data files have been excluded.
    """
    from datasets import load_dataset
    import multiprocessing

    label = f"{dataset_name}/{subset}" if subset else dataset_name
    print(f"Loading dataset: {label} (split={split}, streaming={streaming})")
    if max_samples:
        print(f" Max samples: {max_samples:,}")
    if max_tokens:
        print(f" Max tokens: {max_tokens:,}")

    ds_kwargs = {}
    if subset:
        ds_kwargs["name"] = subset
    if streaming:
        ds_kwargs["streaming"] = True

    # When dataset_name is a local directory and we stream, enumerate its
    # JSON/JSONL files individually so corrupt ones can be excluded later.
    data_files = None
    dataset_path = Path(dataset_name)
    if streaming and dataset_path.exists() and dataset_path.is_dir():
        patterns = ["*.jsonl", "*.jsonl.zst", "*.json", "*.json.gz"]
        files = []
        for pattern in patterns:
            files.extend(dataset_path.rglob(pattern))
        if files:
            data_files = sorted({str(p) for p in files})

    def _load_streaming_dataset(exclude_files: set[str] | None = None):
        # (Re)open the dataset, optionally dropping known-bad local files.
        if data_files:
            if exclude_files:
                files = [f for f in data_files if f not in exclude_files]
            else:
                files = data_files
            if not files:
                raise ValueError("All data files were excluded; cannot continue streaming.")
            return load_dataset("json", data_files=files, split=split, streaming=True)
        return load_dataset(dataset_name, split=split, **ds_kwargs)

    ds = _load_streaming_dataset()

    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab_size = len(tokenizer)  # includes added special tokens

    if num_proc <= 0:
        num_proc = min(multiprocessing.cpu_count(), 16)

    # Detect text column: prefer "text", else fall back to the first column.
    if streaming:
        first = next(iter(ds))  # fresh iterator; stream restarts on later iteration
        text_col = "text" if "text" in first else list(first.keys())[0]
    else:
        text_col = "text" if "text" in ds.column_names else ds.column_names[0]
    print(f" text_col='{text_col}', vocab_size={vocab_size}, num_proc={num_proc}")

    def _read_existing_count(path: str) -> int:
        # Parse an existing FREQTOK1 header and return its num_tokens so
        # appends continue the running totals. Raises on magic/vocab mismatch.
        with open(path, "rb") as f:
            magic = f.read(8)
            if magic != HEADER_MAGIC:
                raise ValueError(f"{path} is not a valid FREQTOK1 file")
            _version = struct.unpack("<I", f.read(4))[0]
            num_tokens = struct.unpack("<Q", f.read(8))[0]
            _seq_len = struct.unpack("<I", f.read(4))[0]
            file_vocab = struct.unpack("<I", f.read(4))[0]
            if file_vocab != vocab_size:
                raise ValueError(
                    f"Cannot append: vocab_size mismatch ({file_vocab} in file vs {vocab_size} new). "
                    f"Use the same tokenizer."
                )
            return num_tokens

    # ---- Non-streaming: parallel tokenization via datasets.map() ----
    if not streaming:
        if max_samples:
            ds = ds.select(range(min(max_samples, len(ds))))

        def _tok_batch(batch):
            # add_special_tokens=False: store the raw token stream only.
            encoded = tokenizer(batch[text_col], add_special_tokens=False)
            return {"input_ids": encoded["input_ids"]}

        print(f"\nTokenizing with {num_proc} processes...")
        ds_tok = ds.map(
            _tok_batch, batched=True, batch_size=10000,
            num_proc=num_proc, remove_columns=ds.column_names,
            desc="Tokenizing",
        )

        # Direct disk writes (constant RAM)
        dtype = np.uint16 if vocab_size < 65536 else np.uint32
        train_path = os.path.join(out_dir, "train.bin")
        val_path = os.path.join(out_dir, "val.bin")
        if not append:
            _init_binary(train_path, vocab_size)
            if val_ratio > 0:
                _init_binary(val_path, vocab_size)
        else:
            # Appending: create any missing file so open("r+b") below succeeds.
            if not os.path.exists(train_path):
                _init_binary(train_path, vocab_size)
            if val_ratio > 0 and not os.path.exists(val_path):
                _init_binary(val_path, vocab_size)
        train_written = _read_existing_count(train_path) if append else 0
        val_written = _read_existing_count(val_path) if (append and val_ratio > 0) else 0

        rng = np.random.default_rng(seed)
        train_f = open(train_path, "r+b")
        train_f.seek(0, 2)  # position at EOF for appending token data
        val_f = None
        if val_ratio > 0:
            val_f = open(val_path, "r+b")
            val_f.seek(0, 2)

        total_tokens = 0
        row_count = 0
        log_every = 100000
        try:
            for row in ds_tok:
                ids = row["input_ids"]
                if not ids:
                    continue
                arr = np.array(ids, dtype=dtype)
                # Whole documents go to train OR val (no mid-document splits).
                if val_ratio > 0 and rng.random() < val_ratio:
                    val_f.write(arr.tobytes())
                    val_written += len(arr)
                else:
                    train_f.write(arr.tobytes())
                    train_written += len(arr)
                total_tokens = train_written + val_written
                row_count += 1
                if row_count % log_every == 0:
                    print(f" {row_count:>10,} docs | {total_tokens:>12,} tokens "
                          f"(train: {train_written:,}, val: {val_written:,})")
                if max_tokens and total_tokens >= max_tokens:
                    break
            # Patch num_tokens back into each header (offset 12 = magic + version).
            train_f.seek(12)
            train_f.write(struct.pack("<Q", train_written))
            if val_ratio > 0:
                val_f.seek(12)
                val_f.write(struct.pack("<Q", val_written))
        finally:
            train_f.close()
            if val_f is not None:
                val_f.close()

        print(f" Total: {train_written + val_written:,} tokens")
        print(f" Train: {train_written:,} tokens, Val: {val_written:,} tokens")
        return

    # ---- Streaming: per-batch disk writes (constant RAM) ----
    dtype = np.uint16 if vocab_size < 65536 else np.uint32
    train_path = os.path.join(out_dir, "train.bin")
    val_path = os.path.join(out_dir, "val.bin")
    if append:
        if not os.path.exists(train_path):
            _init_binary(train_path, vocab_size)
        if val_ratio > 0 and not os.path.exists(val_path):
            _init_binary(val_path, vocab_size)
        train_written = _read_existing_count(train_path)
        val_written = _read_existing_count(val_path) if val_ratio > 0 else 0
    else:
        _init_binary(train_path, vocab_size)
        train_written = 0
        if val_ratio > 0:
            _init_binary(val_path, vocab_size)
        val_written = 0

    count = 0
    batch_texts = []
    batch_size = 10000
    log_every = 100000
    rng = np.random.default_rng(seed)

    train_f = open(train_path, "r+b")
    train_f.seek(0, 2)  # append token data after existing content
    val_f = None
    if val_ratio > 0:
        val_f = open(val_path, "r+b")
        val_f.seek(0, 2)

    def _write_ids(ids_list: list[list[int]]):
        # Route each tokenized document to train or val; update running counts.
        nonlocal train_written, val_written
        for ids in ids_list:
            arr = np.array(ids, dtype=dtype)
            if val_ratio > 0 and rng.random() < val_ratio:
                val_f.write(arr.tobytes())
                val_written += len(arr)
            else:
                train_f.write(arr.tobytes())
                train_written += len(arr)

    print(f"\nTokenizing (streaming, batch_size={batch_size:,}, direct writes)...")

    exclude_files: set[str] = set()
    skip_errors_enabled = skip_errors and data_files is not None

    def _extract_bad_file(err: Exception) -> str | None:
        # Best-effort parse of a datasets/pyarrow error message to recover
        # the offending source file path ("...::<path>" patterns).
        msg = str(err)
        match = re.search(r"file '.*::([^']+)'", msg)
        if match:
            return match.group(1)
        match = re.search(r"::([^'\s]+)", msg)
        if match:
            return match.group(1)
        return None

    try:
        while True:
            if skip_errors_enabled and exclude_files:
                # Reopen the stream without the known-bad files...
                ds = _load_streaming_dataset(exclude_files)
            if skip_errors_enabled and count > 0:
                # ...and fast-forward past already-processed samples.
                # NOTE(review): `count` only counts non-empty texts, while
                # skip() skips raw examples — if empty texts were consumed
                # before a failure, a few examples may be re-processed after
                # a retry. Confirm whether duplicates are acceptable here.
                ds = ds.skip(count)
            try:
                for example in ds:
                    if max_samples and count >= max_samples:
                        break
                    if max_tokens and (train_written + val_written) >= max_tokens:
                        break
                    text = example[text_col]
                    if not text or not text.strip():
                        continue
                    batch_texts.append(text)
                    count += 1
                    if len(batch_texts) >= batch_size:
                        encoded = tokenizer(
                            batch_texts,
                            add_special_tokens=False,
                            return_attention_mask=False,
                            return_token_type_ids=False,
                        )
                        _write_ids(encoded["input_ids"])
                        batch_texts = []
                    if count % log_every == 0:
                        total_tokens = train_written + val_written
                        print(
                            f" {count:>10,} samples | {total_tokens:>12,} tokens "
                            f"(train: {train_written:,}, val: {val_written:,})"
                        )
                break  # clean pass over the stream -> done
            except Exception as err:
                if not skip_errors_enabled:
                    raise
                bad_file = _extract_bad_file(err)
                if not bad_file or bad_file in exclude_files:
                    raise  # cannot identify (or already skipped) the culprit
                print(f" Warning: failed to read {bad_file}; skipping and resuming...")
                exclude_files.add(bad_file)
                continue

        # Flush remaining
        if batch_texts:
            encoded = tokenizer(
                batch_texts,
                add_special_tokens=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            _write_ids(encoded["input_ids"])

        # Patch final token counts into the headers (offset 12 = magic + version).
        train_f.seek(12)
        train_f.write(struct.pack("<Q", train_written))
        if val_ratio > 0:
            val_f.seek(12)
            val_f.write(struct.pack("<Q", val_written))
    finally:
        train_f.close()
        if val_f is not None:
            val_f.close()

    print(f" Total: {count:,} samples, {train_written + val_written:,} tokens written")
    print(f" Train: {train_written:,} tokens, Val: {val_written:,} tokens")
def tokenize_files(files: list[str], tokenizer_name: str, out_dir: str, append: bool = False):
    """Tokenize raw text files.

    Each input file produces <out_dir>/<stem>.bin; with append=True, existing
    binaries are extended instead of overwritten.
    """
    print(f"Loading tokenizer: {tokenizer_name}")
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    # len() includes added special tokens, unlike tok.vocab_size.
    n_vocab = len(tok)
    if append:
        save = append_binary
    else:
        save = write_binary
    for file_path in files:
        print(f"\nTokenizing {file_path}...")
        raw = Path(file_path).read_text(encoding="utf-8")
        ids = tok.encode(raw, add_special_tokens=False)
        target = os.path.join(out_dir, f"{Path(file_path).stem}.bin")
        save(ids, target, n_vocab)
def main():
    """Command-line entry point: parse arguments and dispatch to a tokenizer."""
    cli = argparse.ArgumentParser(description="Tokenize datasets for FreqFormer training")
    cli.add_argument("--dataset", type=str, default=None, help="HuggingFace dataset name")
    cli.add_argument("--subset", type=str, default=None, help="Dataset subset/config")
    cli.add_argument("--files", nargs="+", default=None, help="Raw text files to tokenize")
    cli.add_argument("--tokenizer", type=str, default="freqformer/tokenizer", help="Tokenizer name/path")
    cli.add_argument("--out", type=str, required=True, help="Output directory")
    cli.add_argument("--max_samples", type=int, default=0, help="Max samples to take (0=all)")
    cli.add_argument("--split", type=str, default="train", help="Dataset split to use")
    cli.add_argument("--val_ratio", type=float, default=0.02, help="Fraction for validation split")
    cli.add_argument("--streaming", action="store_true", help="Use streaming mode for large datasets")
    cli.add_argument("--append", action="store_true", help="Append to existing binary files instead of overwriting")
    cli.add_argument("--num_proc", type=int, default=0, help="Number of CPU processes for tokenization (0=auto)")
    cli.add_argument("--max_tokens", type=int, default=0, help="Max tokens to produce (0=all)")
    cli.add_argument("--seed", type=int, default=1234, help="RNG seed for streaming val split")
    cli.add_argument("--skip_errors", action="store_true", help="Skip corrupted JSON files in streaming mode")
    opts = cli.parse_args()

    if opts.append:
        print("Mode: APPEND (will add to existing data files)")

    # Dataset mode takes precedence over raw-file mode; one of the two is required.
    if opts.dataset:
        tokenize_hf_dataset(
            opts.dataset, opts.subset, opts.tokenizer, opts.out,
            max_samples=opts.max_samples,
            split=opts.split,
            val_ratio=opts.val_ratio,
            streaming=opts.streaming,
            append=opts.append,
            num_proc=opts.num_proc,
            max_tokens=opts.max_tokens,
            seed=opts.seed,
            skip_errors=opts.skip_errors,
        )
    elif opts.files:
        tokenize_files(opts.files, opts.tokenizer, opts.out, append=opts.append)
    else:
        cli.error("Specify --dataset or --files")

    print(f"\nDone! Data saved to {opts.out}/")
    print(f"Use --data_dir {opts.out} when training.")
# Script entry point: `python -m freqformer.tokenize_data ...`
if __name__ == "__main__":
    main()