Mandeep Sidhu
Add reproducible WikiText corpus prep
e39c73c
"""
Derived from Andrej Karpathy's nanochat project.
MIT License
Copyright (c) 2025 Andrej Karpathy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"""
from __future__ import annotations
from array import array
from dataclasses import dataclass
from glob import glob
from itertools import islice
from pathlib import Path
from typing import Iterable
import numpy as np
@dataclass(frozen=True)
class TokenSplits:
    """Immutable bundle of train/validation tokens and the cache files behind them.

    Instances are built by the split helpers in this module, which take the
    validation tokens from the tail of the encoded stream and the training
    tokens from everything before it.
    """

    # Training tokens (prefix of the encoded token stream).
    train: np.ndarray
    # Validation tokens (suffix of the encoded token stream).
    val: np.ndarray
    # Path of the tokenizer JSON file associated with these tokens.
    tokenizer_path: Path
    # Path of the ``.npy`` file the tokens were read from.
    encoded_path: Path
@dataclass(frozen=True)
class CachedTokenizer:
    """Lightweight, immutable record of a cached tokenizer's vocabulary size.

    Only ``vocab_size`` is stored. No tokenizer state is loaded into this
    object.
    """

    # Vocabulary size the cached tokenizer and token files were built for.
    vocab_size: int
def required_token_count(max_required_train_tokens: int, val_tokens: int) -> int:
    """Return the minimum number of encoded tokens a run needs.

    The result is the larger of two quantities:

    * the training budget plus the validation tokens, and
    * ten times ``val_tokens``.

    The second keeps the corpus large enough for a full-size validation split
    when the splitters cap validation at one tenth of the corpus.
    """
    train_plus_val = max_required_train_tokens + val_tokens
    ten_times_val = val_tokens * 10
    if train_plus_val >= ten_times_val:
        return train_plus_val
    return ten_times_val
def resolve_paths(corpus: str | None, corpus_glob: str | None) -> list[Path]:
    """Collect corpus files from an explicit path and/or a glob pattern.

    Args:
        corpus: A single corpus file path, or ``None``/empty to skip.
        corpus_glob: A glob pattern whose matches are added in sorted order,
            or ``None``/empty to skip.

    Returns:
        ``corpus`` first (if given), then the sorted glob matches. Duplicates are
        removed, keeping the first occurrence, so no file is read twice.

    Raises:
        ValueError: if no path was collected, either because neither argument
            was given or because the glob matched nothing.
        FileNotFoundError: if any collected path does not exist.
    """
    paths: list[Path] = []
    if corpus:
        paths.append(Path(corpus))
    if corpus_glob:
        paths.extend(Path(p) for p in sorted(glob(corpus_glob)))
    if not paths:
        if corpus_glob:
            # A glob was supplied but matched nothing. Say so instead of implying
            # that no corpus flag was passed at all.
            raise ValueError(f"--corpus-glob matched no files: {corpus_glob}")
        raise ValueError("provide --corpus or --corpus-glob")
    # The same file can be named via --corpus and also matched by the glob.
    # Reading it twice would duplicate its tokens in the encoded stream.
    paths = list(dict.fromkeys(paths))
    missing = [str(p) for p in paths if not p.exists()]
    if missing:
        raise FileNotFoundError("missing corpus paths: " + ", ".join(missing))
    return paths
def load_cached_splits(
    *,
    cache_dir: Path,
    vocab_size: int,
    max_required_train_tokens: int,
    val_tokens: int,
    allow_short_corpus: bool,
) -> tuple[CachedTokenizer, TokenSplits]:
    """Load a previously built tokenizer/token cache from ``cache_dir``.

    The tokenizer file ``tokenizer-v{vocab_size}.json`` is only checked for
    existence; the returned ``CachedTokenizer`` records just ``vocab_size``.

    Among the ``tokens-v{vocab_size}-*.npy`` files, the one whose dtype suffix
    matches what ``encode_corpus`` writes for this vocab size is preferred
    (``uint16`` when the vocab fits, otherwise ``uint32``). If that file is
    absent, the first file in sorted order is used. Tokens are memory-mapped
    read-only. The validation split is the last ``min(val_tokens, len // 10)``
    tokens; the training split is everything before it.

    Raises:
        FileNotFoundError: if the tokenizer file or every token file is missing.
        ValueError: if the cache is shorter than the run requires (unless
            ``allow_short_corpus``), or if it cannot supply more than one
            validation token.
    """
    tokenizer_path = cache_dir / f"tokenizer-v{vocab_size}.json"
    if not tokenizer_path.exists():
        raise FileNotFoundError(
            f"cached tokenizer not found: {tokenizer_path}. "
            "Provide --corpus/--corpus-glob to build it, or copy the cache into this project."
        )
    # Mirror encode_corpus's dtype rule so the file it would write is the one preferred.
    dtype_name = "uint16" if vocab_size <= np.iinfo(np.uint16).max else "uint32"
    preferred = cache_dir / f"tokens-v{vocab_size}-{dtype_name}.npy"
    token_paths = sorted(cache_dir.glob(f"tokens-v{vocab_size}-*.npy"))
    if not token_paths:
        raise FileNotFoundError(
            f"cached encoded token file not found under {cache_dir}. "
            f"Expected a file like {preferred.name}."
        )
    encoded_path = preferred if preferred.exists() else token_paths[0]

    tokenizer = CachedTokenizer(vocab_size=vocab_size)
    tokens = np.load(encoded_path, mmap_mode="r")
    need_total = required_token_count(max_required_train_tokens, val_tokens)
    if len(tokens) < need_total and not allow_short_corpus:
        raise ValueError(
            f"cached token file has {len(tokens):,} tokens, but {need_total:,} "
            "are required by this run"
        )
    # Validation never takes more than a tenth of the stream.
    usable_val = min(val_tokens, len(tokens) // 10)
    if usable_val <= 1:
        raise ValueError("not enough cached tokens for validation")
    train = tokens[:-usable_val]
    val = tokens[-usable_val:]
    if len(train) < max_required_train_tokens and not allow_short_corpus:
        raise ValueError(
            f"cached train split has {len(train):,} tokens, but "
            f"{max_required_train_tokens:,} are required"
        )
    return tokenizer, TokenSplits(
        train=train,
        val=val,
        tokenizer_path=tokenizer_path,
        encoded_path=encoded_path,
    )
def iter_documents(
    paths: Iterable[Path], text_column: str = "text", doc_cap_chars: int | None = None
) -> Iterable[str]:
    """Lazily yield text documents from each corpus file in ``paths``.

    Parquet files (a ``.parquet`` suffix, in any case):
        ``text_column`` is read in batches of 2048 rows. Empty or ``None``
        cells are skipped. When ``doc_cap_chars`` is truthy, each document is
        cut to that many characters.

    Any other file:
        The whole file is read as UTF-8 and yielded as consecutive slices of
        ``doc_cap_chars`` characters, or 1,000,000 when it is falsy. No text
        is discarded.
    """
    for source in paths:
        if source.suffix.lower() == ".parquet":
            # Imported here so pyarrow is only needed once a parquet file is read.
            import pyarrow.parquet as pq

            reader = pq.ParquetFile(source)
            for record_batch in reader.iter_batches(columns=[text_column], batch_size=2048):
                for cell in record_batch.column(0).to_pylist():
                    if not cell:
                        continue
                    yield cell[:doc_cap_chars] if doc_cap_chars else cell
        else:
            whole = source.read_text(encoding="utf-8")
            step = doc_cap_chars or 1_000_000
            yield from (whole[offset : offset + step] for offset in range(0, len(whole), step))
def train_or_load_tokenizer(
    paths: list[Path],
    output_dir: Path,
    vocab_size: int,
    tokenizer_train_chars: int,
    text_column: str,
    force_retrain: bool = False,
) -> object:
    """Return the cached BPE tokenizer for ``vocab_size``, training one if needed.

    If ``output_dir/tokenizer-v{vocab_size}.json`` exists and ``force_retrain``
    is false, the tokenizer is loaded from it. Otherwise a new tokenizer is
    trained on at most ``tokenizer_train_chars`` characters from ``paths``.
    Documents are read with a 20,000-character per-document cap. The new
    tokenizer is saved to that path and returned.
    """
    from dropout_decay.tokenizer import BpeTokenizer

    tokenizer_path = output_dir / f"tokenizer-v{vocab_size}.json"
    if tokenizer_path.exists() and not force_retrain:
        return BpeTokenizer.from_file(tokenizer_path)

    def capped_docs() -> Iterable[str]:
        """Yield document text until ``tokenizer_train_chars`` characters have been emitted."""
        remaining = tokenizer_train_chars
        if remaining <= 0:
            return
        for doc in iter_documents(paths, text_column=text_column, doc_cap_chars=20_000):
            piece = doc[:remaining]
            remaining -= len(piece)
            yield piece
            if remaining <= 0:
                # Stop now rather than pulling another document just to discard it.
                # For plain-text files that would read an entire extra file.
                return

    tokenizer = BpeTokenizer.train_from_iterator(capped_docs(), vocab_size)
    # The cache dir may not exist yet on a fresh run. Create it in case the
    # tokenizer's save() does not create parent directories itself.
    output_dir.mkdir(parents=True, exist_ok=True)
    tokenizer.save(tokenizer_path)
    return tokenizer
def encode_corpus(
    paths: list[Path],
    tokenizer: BpeTokenizer,
    output_dir: Path,
    max_required_train_tokens: int,
    val_tokens: int,
    text_column: str,
    allow_short_corpus: bool,
    force_reencode: bool = False,
) -> TokenSplits:
    """Encode ``paths`` into a cached ``.npy`` token file and split it into train/val.

    The token file is named after the tokenizer's vocab size and the dtype used:
    ``uint16`` when the vocab fits, otherwise ``uint32``.

    An existing token file is reused unless ``force_reencode`` is set, or it
    holds fewer than ``required_token_count(...)`` tokens while
    ``allow_short_corpus`` is false. When encoding, each document goes through
    ``tokenizer.encode(doc, prepend_bos=True)``. Encoding stops once at least
    the required number of tokens has been produced.

    Returns:
        ``TokenSplits`` backed by a read-only memory map of the token file. The
        validation split is the last ``min(val_tokens, len // 10)`` tokens.

    Raises:
        ValueError: if the corpus or the train split is shorter than required
            (unless ``allow_short_corpus``), or if the stream cannot supply more
            than one validation token.
    """
    dtype = np.uint16 if tokenizer.vocab_size <= np.iinfo(np.uint16).max else np.uint32
    encoded_path = output_dir / f"tokens-v{tokenizer.vocab_size}-{dtype.__name__}.npy"
    tokenizer_path = output_dir / f"tokenizer-v{tokenizer.vocab_size}.json"
    need_total = required_token_count(max_required_train_tokens, val_tokens)

    tokens = None
    if not force_reencode and encoded_path.exists():
        cached_tokens = np.load(encoded_path, mmap_mode="r")
        if allow_short_corpus or len(cached_tokens) >= need_total:
            tokens = cached_tokens
        else:
            # Release the memory map before this file is rewritten below. An open
            # mapping prevents replacing the file on some platforms (e.g. Windows).
            del cached_tokens

    if tokens is None:
        ids = array("I")
        for doc in iter_documents(paths, text_column=text_column, doc_cap_chars=None):
            ids.extend(tokenizer.encode(doc, prepend_bos=True))
            if len(ids) >= need_total:
                break
        # array("I") holds C unsigned ints, which are not guaranteed to be 4 bytes.
        # np.uintc is the matching NumPy type, so the buffer is read at its real width.
        encoded = np.frombuffer(ids, dtype=np.uintc).astype(dtype, copy=False)
        if len(encoded) < need_total and not allow_short_corpus:
            raise ValueError(
                f"corpus produced {len(encoded):,} tokens, but {need_total:,} are required"
            )
        output_dir.mkdir(parents=True, exist_ok=True)
        # Write to a hidden temp file, then rename it into place. An interrupted
        # run then never leaves a truncated token file that a later run would
        # load as a valid cache. Passing a file object stops np.save from
        # appending ".npy", which keeps the temp file out of the
        # "tokens-v*-*.npy" cache glob.
        tmp_path = output_dir / f".{encoded_path.name}.tmp"
        with open(tmp_path, "wb") as fh:
            np.save(fh, encoded)
        tmp_path.replace(encoded_path)
        tokens = np.load(encoded_path, mmap_mode="r")

    # Validation never takes more than a tenth of the stream.
    usable_val = min(val_tokens, len(tokens) // 10)
    if usable_val <= 1:
        raise ValueError("not enough tokens for validation")
    train = tokens[:-usable_val]
    val = tokens[-usable_val:]
    if len(train) < max_required_train_tokens and not allow_short_corpus:
        raise ValueError(
            f"train split has {len(train):,} tokens, but {max_required_train_tokens:,} are required"
        )
    return TokenSplits(
        train=train,
        val=val,
        tokenizer_path=tokenizer_path,
        encoded_path=encoded_path,
    )
def preview_documents(paths: list[Path], limit: int = 3) -> list[str]:
    """Return at most ``limit`` documents from ``paths`` for a quick look.

    Documents come from ``iter_documents`` with its default column and no
    per-document cap. Only as many documents as requested are pulled from the
    lazy iterator.
    """
    documents = iter_documents(paths)
    return [*islice(documents, limit)]