"""
Derived from Andrej Karpathy's nanochat project.

MIT License

Copyright (c) 2025 Andrej Karpathy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"""

from __future__ import annotations

from array import array
from dataclasses import dataclass
from glob import glob
from itertools import islice
from pathlib import Path
from typing import TYPE_CHECKING, Iterable

import numpy as np

if TYPE_CHECKING:
    # Only for annotations; the runtime import stays lazy inside the functions.
    from dropout_decay.tokenizer import BpeTokenizer

# Corpus discovery, tokenizer training/caching, and encoding of the corpus into a
# memory-mapped token cache split into train/validation arrays.


@dataclass(frozen=True)
class TokenSplits:
    """Train/validation token arrays plus the files they were derived from."""

    train: np.ndarray  # leading tokens of the encoded corpus
    val: np.ndarray  # trailing tokens held out for validation
    tokenizer_path: Path  # tokenizer-v{vocab_size}.json associated with the tokens
    encoded_path: Path  # .npy file the arrays are memory-mapped from


@dataclass(frozen=True)
class CachedTokenizer:
    """Stand-in tokenizer used when only the cached token file is loaded."""

    vocab_size: int  # vocabulary size the cache was built with


def required_token_count(max_required_train_tokens: int, val_tokens: int) -> int:
    """Return how many tokens a run needs in total.

    This is training plus validation tokens, but at least ten times
    *val_tokens*, because the validation split never takes more than a tenth
    of the available tokens.
    """
    return max(max_required_train_tokens + val_tokens, val_tokens * 10)


def _split_train_val(
    tokens: np.ndarray,
    *,
    max_required_train_tokens: int,
    val_tokens: int,
    allow_short_corpus: bool,
    label: str = "",
) -> tuple[np.ndarray, np.ndarray]:
    """Split *tokens* into ``(train, val)``, holding out the tail for validation.

    The validation split is *val_tokens* long but capped at a tenth of the
    available tokens. *label* is prefixed to error messages (e.g. ``"cached "``).

    Raises:
        ValueError: if at most one token would be left for validation, or the
            train split is shorter than required and *allow_short_corpus* is false.
    """
    usable_val = min(val_tokens, len(tokens) // 10)
    if usable_val <= 1:
        raise ValueError(f"not enough {label}tokens for validation")
    train = tokens[:-usable_val]
    val = tokens[-usable_val:]
    if len(train) < max_required_train_tokens and not allow_short_corpus:
        raise ValueError(
            f"{label}train split has {len(train):,} tokens, but "
            f"{max_required_train_tokens:,} are required"
        )
    return train, val


def resolve_paths(corpus: str | None, corpus_glob: str | None) -> list[Path]:
    """Collect corpus files from an explicit path and/or a glob pattern.

    *corpus* comes first, followed by the glob matches in sorted order.

    Raises:
        ValueError: if no path was collected. The message differs depending on
            whether nothing was requested or the glob matched nothing.
        FileNotFoundError: if any collected path does not exist.
    """
    paths: list[Path] = []
    if corpus:
        paths.append(Path(corpus))
    if corpus_glob:
        paths.extend(Path(p) for p in sorted(glob(corpus_glob)))
    if not paths:
        if corpus_glob:
            # The user did supply a pattern; asking for one again would be misleading.
            raise ValueError(f"--corpus-glob matched no files: {corpus_glob}")
        raise ValueError("provide --corpus or --corpus-glob")
    missing = [str(p) for p in paths if not p.exists()]
    if missing:
        raise FileNotFoundError("missing corpus paths: " + ", ".join(missing))
    return paths


def load_cached_splits(
    *,
    cache_dir: Path,
    vocab_size: int,
    max_required_train_tokens: int,
    val_tokens: int,
    allow_short_corpus: bool,
) -> tuple[CachedTokenizer, TokenSplits]:
    """Load previously encoded tokens from *cache_dir* without re-tokenizing.

    Requires ``tokenizer-v{vocab_size}.json`` and at least one
    ``tokens-v{vocab_size}-*.npy`` file. The ``uint16`` token file is preferred;
    otherwise the first match in sorted order is used. The file is opened
    read-only with ``mmap_mode="r"``, so the returned splits are views into it.

    Raises:
        FileNotFoundError: if the tokenizer or token file is missing.
        ValueError: if the cache is too short (unless *allow_short_corpus*) or
            too small to carve out a validation split.
    """
    tokenizer_path = cache_dir / f"tokenizer-v{vocab_size}.json"
    token_paths = sorted(cache_dir.glob(f"tokens-v{vocab_size}-*.npy"))
    if not tokenizer_path.exists():
        raise FileNotFoundError(
            f"cached tokenizer not found: {tokenizer_path}. "
            "Provide --corpus/--corpus-glob to build it, or copy the cache into this project."
        )
    if not token_paths:
        raise FileNotFoundError(
            f"cached encoded token file not found under {cache_dir}. "
            "Expected a file like tokens-v4096-uint16.npy."
        )
    preferred = cache_dir / f"tokens-v{vocab_size}-uint16.npy"
    encoded_path = preferred if preferred.exists() else token_paths[0]

    tokens = np.load(encoded_path, mmap_mode="r")
    need_total = required_token_count(max_required_train_tokens, val_tokens)
    if len(tokens) < need_total and not allow_short_corpus:
        raise ValueError(
            f"cached token file has {len(tokens):,} tokens, but {need_total:,} "
            "are required by this run"
        )
    train, val = _split_train_val(
        tokens,
        max_required_train_tokens=max_required_train_tokens,
        val_tokens=val_tokens,
        allow_short_corpus=allow_short_corpus,
        label="cached ",
    )
    return CachedTokenizer(vocab_size=vocab_size), TokenSplits(
        train=train,
        val=val,
        tokenizer_path=tokenizer_path,
        encoded_path=encoded_path,
    )


def iter_documents(
    paths: Iterable[Path], text_column: str = "text", doc_cap_chars: int | None = None
) -> Iterable[str]:
    """Yield text documents from each path in order.

    * ``.parquet`` files: one document per non-empty value of *text_column*,
      truncated to *doc_cap_chars* characters when that is set.
    * Any other file: decoded as UTF-8 and yielded as consecutive chunks of
      *doc_cap_chars* characters (1,000,000 when unset). The file is streamed
      rather than read into memory whole.
    """
    for path in paths:
        if path.suffix.lower() == ".parquet":
            # Imported lazily so pyarrow is only needed when parquet input is used.
            import pyarrow.parquet as pq

            parquet_file = pq.ParquetFile(path)
            for batch in parquet_file.iter_batches(columns=[text_column], batch_size=2048):
                for text in batch.column(0).to_pylist():
                    if text:  # skips None (null) and empty strings
                        yield text[:doc_cap_chars] if doc_cap_chars else text
        else:
            chunk_size = doc_cap_chars or 1_000_000
            # Text-mode read(n) returns exactly n characters until EOF, so these
            # chunks match slicing the fully decoded text.
            with path.open(encoding="utf-8") as handle:
                while True:
                    chunk = handle.read(chunk_size)
                    if not chunk:
                        break
                    yield chunk


def train_or_load_tokenizer(
    paths: list[Path],
    output_dir: Path,
    vocab_size: int,
    tokenizer_train_chars: int,
    text_column: str,
    force_retrain: bool = False,
) -> BpeTokenizer:
    """Return the tokenizer cached in *output_dir*, training a new one if needed.

    An existing ``tokenizer-v{vocab_size}.json`` is loaded unless
    *force_retrain* is set. Otherwise a BPE tokenizer is trained on documents
    from *paths*. Each document is capped at 20,000 characters, and at most
    *tokenizer_train_chars* characters are used in total. The result is saved to
    *output_dir*, which is created if missing.
    """
    from dropout_decay.tokenizer import BpeTokenizer

    tokenizer_path = output_dir / f"tokenizer-v{vocab_size}.json"
    if tokenizer_path.exists() and not force_retrain:
        return BpeTokenizer.from_file(tokenizer_path)

    def capped_docs() -> Iterable[str]:
        # Stop once the character budget is spent; the last piece may be cut short.
        chars = 0
        for doc in iter_documents(paths, text_column=text_column, doc_cap_chars=20_000):
            if chars >= tokenizer_train_chars:
                break
            piece = doc[: tokenizer_train_chars - chars]
            chars += len(piece)
            yield piece

    tokenizer = BpeTokenizer.train_from_iterator(capped_docs(), vocab_size)
    # On a fresh run nothing has created output_dir yet, so create it before saving.
    output_dir.mkdir(parents=True, exist_ok=True)
    tokenizer.save(tokenizer_path)
    return tokenizer


def _save_npy_atomic(path: Path, values: np.ndarray) -> None:
    """Write *values* to *path* in ``.npy`` format via a temp file and a rename.

    An interrupted write never leaves a truncated file at *path*. The temp name
    starts with ``.`` and ends in ``.tmp``, so it cannot match the
    ``tokens-v*-*.npy`` pattern globbed by ``load_cached_splits``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(f".{path.name}.tmp")
    try:
        # Saving through a file object stops np.save from appending ".npy" to the name.
        with open(tmp_path, "wb") as handle:
            np.save(handle, values)
        tmp_path.replace(path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def encode_corpus(
    paths: list[Path],
    tokenizer: BpeTokenizer,
    output_dir: Path,
    max_required_train_tokens: int,
    val_tokens: int,
    text_column: str,
    allow_short_corpus: bool,
    force_reencode: bool = False,
) -> TokenSplits:
    """Encode *paths* with *tokenizer* into a cached ``.npy`` file and split it.

    The cache file is ``tokens-v{vocab_size}-{dtype}.npy``, stored as uint16
    when the vocabulary size fits and uint32 otherwise. An existing cache is
    reused unless *force_reencode* is set, or it holds fewer tokens than
    required and *allow_short_corpus* is false.

    Encoding prepends BOS to every document (``prepend_bos=True``) and stops as
    soon as enough tokens exist. The returned splits are memory-mapped from the
    cache file.

    Raises:
        ValueError: if the corpus is too short (unless *allow_short_corpus*) or
            too small to carve out a validation split.
    """
    # NOTE(review): assumes every token id (BOS included) is < vocab_size, so
    # uint16 is only chosen when the ids fit.
    dtype = np.uint16 if tokenizer.vocab_size <= np.iinfo(np.uint16).max else np.uint32
    encoded_path = output_dir / f"tokens-v{tokenizer.vocab_size}-{dtype.__name__}.npy"
    tokenizer_path = output_dir / f"tokenizer-v{tokenizer.vocab_size}.json"
    need_total = required_token_count(max_required_train_tokens, val_tokens)

    tokens: np.ndarray | None = None
    if not force_reencode and encoded_path.exists():
        cached_tokens = np.load(encoded_path, mmap_mode="r")
        if len(cached_tokens) >= need_total or allow_short_corpus:
            tokens = cached_tokens
        # If the cache is rejected, release its memory map before the file is
        # replaced below. Some platforms (e.g. Windows) refuse to replace a mapped file.
        del cached_tokens

    if tokens is None:
        ids = array("I")  # compact growable buffer of C unsigned ints
        for doc in iter_documents(paths, text_column=text_column, doc_cap_chars=None):
            ids.extend(tokenizer.encode(doc, prepend_bos=True))
            if len(ids) >= need_total:
                break  # enough for this run; skip the rest of the corpus
        if len(ids) < need_total and not allow_short_corpus:
            raise ValueError(
                f"corpus produced {len(ids):,} tokens, but {need_total:,} are required"
            )
        # np.uintc is C `unsigned int`, exactly the element type of array("I").
        encoded = np.frombuffer(ids, dtype=np.uintc).astype(dtype, copy=False)
        _save_npy_atomic(encoded_path, encoded)
        tokens = np.load(encoded_path, mmap_mode="r")

    train, val = _split_train_val(
        tokens,
        max_required_train_tokens=max_required_train_tokens,
        val_tokens=val_tokens,
        allow_short_corpus=allow_short_corpus,
    )
    return TokenSplits(
        train=train,
        val=val,
        tokenizer_path=tokenizer_path,
        encoded_path=encoded_path,
    )


def preview_documents(paths: list[Path], limit: int = 3) -> list[str]:
    """Return up to *limit* documents from *paths* using ``iter_documents`` defaults."""
    return list(islice(iter_documents(paths), limit))