| """ |
| Derived from Andrej Karpathy's nanochat project. |
| |
| MIT License |
| |
| Copyright (c) 2025 Andrej Karpathy |
| |
| Permission is hereby granted, free of charge, to any person obtaining a copy |
| of this software and associated documentation files (the "Software"), to deal |
| in the Software without restriction, including without limitation the rights |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| copies of the Software, and to permit persons to whom the Software is |
| furnished to do so, subject to the following conditions: |
| |
| The above copyright notice and this permission notice shall be included in all |
| copies or substantial portions of the Software. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from array import array |
| from dataclasses import dataclass |
| from glob import glob |
| from itertools import islice |
| from pathlib import Path |
| from typing import Iterable |
|
|
| import numpy as np |
|
|
|
|
@dataclass(frozen=True)
class TokenSplits:
    """Train/validation token arrays plus the files they were derived from.

    Both loaders in this module build ``train`` and ``val`` by slicing one array
    obtained from ``np.load(..., mmap_mode="r")``: ``val`` is the tail of the
    array and ``train`` is everything before it.
    """

    # Leading portion of the encoded token array (everything except the val tail).
    train: np.ndarray
    # Trailing validation slice of the same encoded token array.
    val: np.ndarray
    # Location of the tokenizer JSON, named "tokenizer-v{vocab_size}.json".
    tokenizer_path: Path
    # Location of the .npy file the token array was loaded from.
    encoded_path: Path
|
|
|
|
@dataclass(frozen=True)
class CachedTokenizer:
    """Minimal tokenizer record returned by ``load_cached_splits``.

    It carries only ``vocab_size``; no tokenizer file is parsed to build it.
    NOTE(review): presumably used where the real tokenizer is not needed
    (cached-token runs) — confirm callers only read ``vocab_size``.
    """

    # Vocabulary size supplied by the caller of load_cached_splits.
    vocab_size: int
|
|
|
|
def required_token_count(max_required_train_tokens: int, val_tokens: int) -> int:
    """Return the total number of tokens a run needs.

    This is the larger of (a) the train budget plus the validation budget and
    (b) ten times the validation budget. Term (b) exists because the loaders in
    this module cap the validation split at ``len(tokens) // 10``, so only a
    corpus of at least ``10 * val_tokens`` tokens yields the full validation
    split.
    """
    combined = max_required_train_tokens + val_tokens
    floor_for_full_val = 10 * val_tokens
    return floor_for_full_val if floor_for_full_val > combined else combined
|
|
|
|
def resolve_paths(corpus: str | None, corpus_glob: str | None) -> list[Path]:
    """Collect corpus files from an explicit path and/or a glob pattern.

    The explicit ``corpus`` path (when given) comes first, followed by the
    sorted matches of ``corpus_glob``. A file named by both is kept only once
    (compared by resolved location), so its documents are not encoded twice.

    Args:
        corpus: A single corpus file path, or None/empty to skip.
        corpus_glob: A ``glob.glob`` pattern, or None/empty to skip.

    Returns:
        The de-duplicated list of corpus paths, in the order described above.

    Raises:
        ValueError: if neither argument is given, or only a glob is given and
            it matches no files.
        FileNotFoundError: if the explicit ``corpus`` path does not exist.
    """
    candidates: list[Path] = []
    if corpus:
        candidates.append(Path(corpus))
    if corpus_glob:
        candidates.extend(Path(p) for p in sorted(glob(corpus_glob)))
    if not candidates:
        if corpus_glob:
            # The user did pass a glob; say so instead of asking for one.
            raise ValueError(f"--corpus-glob matched no files: {corpus_glob!r}")
        raise ValueError("provide --corpus or --corpus-glob")

    # Drop duplicates (e.g. --corpus also matched by --corpus-glob) while
    # preserving first-seen order.
    paths: list[Path] = []
    seen: set[Path] = set()
    for path in candidates:
        key = path.resolve()
        if key not in seen:
            seen.add(key)
            paths.append(path)

    missing = [str(p) for p in paths if not p.exists()]
    if missing:
        raise FileNotFoundError("missing corpus paths: " + ", ".join(missing))
    return paths
|
|
|
|
def load_cached_splits(
    *,
    cache_dir: Path,
    vocab_size: int,
    max_required_train_tokens: int,
    val_tokens: int,
    allow_short_corpus: bool,
) -> tuple[CachedTokenizer, TokenSplits]:
    """Load previously encoded tokens from ``cache_dir`` and split them.

    Expects ``tokenizer-v{vocab_size}.json`` and at least one
    ``tokens-v{vocab_size}-*.npy`` in ``cache_dir``. When several token files
    exist, the one named with the dtype ``encode_corpus`` would choose for this
    vocabulary (``uint16`` up to 65535, else ``uint32``) is preferred.
    Otherwise the first file in sorted order is used.

    Args:
        cache_dir: Directory holding the tokenizer and encoded-token cache.
        vocab_size: Vocabulary size used in the cache file names.
        max_required_train_tokens: Minimum size of the train split.
        val_tokens: Requested validation size; capped at a tenth of the tokens.
        allow_short_corpus: Accept fewer tokens than required instead of raising.

    Returns:
        A ``CachedTokenizer`` carrying ``vocab_size`` and the ``TokenSplits``,
        whose arrays are slices of the memory-mapped token file.

    Raises:
        FileNotFoundError: if the tokenizer or token file is missing.
        ValueError: if there are too few tokens (unless ``allow_short_corpus``)
            or too few to carve out more than one validation token.
    """
    tokenizer_path = cache_dir / f"tokenizer-v{vocab_size}.json"
    if not tokenizer_path.exists():
        raise FileNotFoundError(
            f"cached tokenizer not found: {tokenizer_path}. "
            "Provide --corpus/--corpus-glob to build it, or copy the cache into this project."
        )

    # Mirror encode_corpus's naming: the smallest unsigned dtype holding the vocab.
    expected_dtype = "uint16" if vocab_size <= np.iinfo(np.uint16).max else "uint32"
    token_paths = sorted(cache_dir.glob(f"tokens-v{vocab_size}-*.npy"))
    if not token_paths:
        raise FileNotFoundError(
            f"cached encoded token file not found under {cache_dir}. "
            f"Expected a file like tokens-v{vocab_size}-{expected_dtype}.npy."
        )
    preferred = cache_dir / f"tokens-v{vocab_size}-{expected_dtype}.npy"
    encoded_path = preferred if preferred.exists() else token_paths[0]

    tokenizer = CachedTokenizer(vocab_size=vocab_size)
    tokens = np.load(encoded_path, mmap_mode="r")
    need_total = required_token_count(max_required_train_tokens, val_tokens)
    if len(tokens) < need_total and not allow_short_corpus:
        raise ValueError(
            f"cached token file has {len(tokens):,} tokens, but {need_total:,} "
            "are required by this run"
        )

    # Validation is the tail of the token stream, capped at a tenth of it.
    usable_val = min(val_tokens, max(0, len(tokens) // 10))
    if usable_val <= 1:
        raise ValueError("not enough cached tokens for validation")
    train = tokens[:-usable_val]
    val = tokens[-usable_val:]
    # Safeguard: with allow_short_corpus False this is already implied by the
    # need_total check above.
    if len(train) < max_required_train_tokens and not allow_short_corpus:
        raise ValueError(
            f"cached train split has {len(train):,} tokens, but "
            f"{max_required_train_tokens:,} are required"
        )
    return tokenizer, TokenSplits(
        train=train,
        val=val,
        tokenizer_path=tokenizer_path,
        encoded_path=encoded_path,
    )
|
|
|
|
def iter_documents(
    paths: Iterable[Path], text_column: str = "text", doc_cap_chars: int | None = None
) -> Iterable[str]:
    """Yield non-empty text documents from corpus files, lazily.

    Parquet files (suffix ``.parquet``, case-insensitive) are read in batches of
    2048 rows from ``text_column``. Null and empty values are skipped, and each
    value is truncated to ``doc_cap_chars`` characters when that is truthy.

    Every other file is decoded as UTF-8 text and yielded as consecutive chunks
    of ``doc_cap_chars`` characters (1,000,000 when ``doc_cap_chars`` is falsy).
    Note that for plain text the cap *splits* the file rather than truncating
    it, so all of its text is yielded.

    Args:
        paths: Corpus files to read, in order.
        text_column: Parquet column holding document text.
        doc_cap_chars: Per-document truncation (parquet) or chunk size (text).

    Yields:
        Non-empty strings.
    """
    for path in paths:
        if path.suffix.lower() == ".parquet":
            import pyarrow.parquet as pq

            parquet_file = pq.ParquetFile(path)
            for batch in parquet_file.iter_batches(columns=[text_column], batch_size=2048):
                for text in batch.column(0).to_pylist():
                    if text:
                        yield text[:doc_cap_chars] if doc_cap_chars else text
        else:
            chunk_size = doc_cap_chars or 1_000_000
            # Stream instead of read_text(): text-mode read(n) returns n
            # characters until EOF (same chunks as slicing the whole decoded
            # file), but never holds the entire corpus in memory. Consumers that
            # stop early (e.g. encode_corpus) also avoid reading the rest.
            with open(path, encoding="utf-8") as handle:
                while True:
                    chunk = handle.read(chunk_size)
                    if not chunk:
                        break
                    yield chunk
|
|
|
|
def train_or_load_tokenizer(
    paths: list[Path],
    output_dir: Path,
    vocab_size: int,
    tokenizer_train_chars: int,
    text_column: str,
    force_retrain: bool = False,
) -> object:
    """Load the cached BPE tokenizer for ``vocab_size``, or train and save one.

    If ``output_dir / "tokenizer-v{vocab_size}.json"`` exists and
    ``force_retrain`` is False, it is loaded with ``BpeTokenizer.from_file``.
    Otherwise a tokenizer is trained on at most ``tokenizer_train_chars``
    characters drawn from ``iter_documents`` with ``doc_cap_chars=20_000``,
    then saved to that path.

    Args:
        paths: Corpus files to sample training text from.
        output_dir: Directory for the tokenizer JSON; created if missing.
        vocab_size: Target vocabulary size (also part of the file name).
        tokenizer_train_chars: Character budget for training text.
        text_column: Parquet column holding document text.
        force_retrain: Retrain even if a cached tokenizer exists.

    Returns:
        The loaded or newly trained ``BpeTokenizer``.
    """
    from dropout_decay.tokenizer import BpeTokenizer

    tokenizer_path = output_dir / f"tokenizer-v{vocab_size}.json"
    if tokenizer_path.exists() and not force_retrain:
        return BpeTokenizer.from_file(tokenizer_path)

    def capped_docs() -> Iterable[str]:
        """Yield document text until ``tokenizer_train_chars`` characters are emitted."""
        chars = 0
        for doc in iter_documents(paths, text_column=text_column, doc_cap_chars=20_000):
            remaining = tokenizer_train_chars - chars
            if remaining <= 0:
                break
            # The last piece is trimmed so the budget is never exceeded.
            piece = doc[:remaining]
            chars += len(piece)
            yield piece

    tokenizer = BpeTokenizer.train_from_iterator(capped_docs(), vocab_size)
    # Nothing earlier in this function creates output_dir (encode_corpus does its
    # own mkdir before writing). Create it here so saving does not depend on
    # BpeTokenizer.save creating parent directories.
    output_dir.mkdir(parents=True, exist_ok=True)
    tokenizer.save(tokenizer_path)
    return tokenizer
|
|
|
|
def encode_corpus(
    paths: list[Path],
    tokenizer: BpeTokenizer,
    output_dir: Path,
    max_required_train_tokens: int,
    val_tokens: int,
    text_column: str,
    allow_short_corpus: bool,
    force_reencode: bool = False,
) -> TokenSplits:
    """Encode the corpus (or reuse a cached encoding) and split it into train/val.

    Token ids are cached in ``output_dir`` as ``tokens-v{vocab_size}-{dtype}.npy``.
    ``dtype`` is ``uint16`` when ``tokenizer.vocab_size`` is at most 65535 and
    ``uint32`` otherwise. An existing cache is reused unless ``force_reencode``
    is set, or unless it is shorter than ``required_token_count(...)`` while
    ``allow_short_corpus`` is False. Encoding stops as soon as that many tokens
    have been produced.

    Args:
        paths: Corpus files, read via ``iter_documents`` without a character cap.
        tokenizer: Object providing ``vocab_size`` and ``encode(text, prepend_bos=True)``.
        output_dir: Cache directory; created if missing when encoding.
        max_required_train_tokens: Minimum size of the train split.
        val_tokens: Requested validation size; capped at a tenth of the tokens.
        text_column: Parquet column holding document text.
        allow_short_corpus: Accept fewer tokens than required instead of raising.
        force_reencode: Ignore any existing cache file.

    Returns:
        ``TokenSplits`` whose ``train``/``val`` are slices of the memory-mapped cache.

    Raises:
        ValueError: if the corpus is too short (and ``allow_short_corpus`` is
            False) or too small to carve out more than one validation token.
    """
    dtype = np.uint16 if tokenizer.vocab_size <= np.iinfo(np.uint16).max else np.uint32
    encoded_path = output_dir / f"tokens-v{tokenizer.vocab_size}-{dtype.__name__}.npy"
    tokenizer_path = output_dir / f"tokenizer-v{tokenizer.vocab_size}.json"
    need_total = required_token_count(max_required_train_tokens, val_tokens)

    tokens: np.ndarray | None = None
    if not force_reencode and encoded_path.exists():
        cached_tokens = np.load(encoded_path, mmap_mode="r")
        if len(cached_tokens) >= need_total or allow_short_corpus:
            tokens = cached_tokens
        else:
            # Release the memory map before the cache is rewritten below.
            # Replacing a mapped file fails on Windows, and on POSIX a stale
            # map over a rewritten file faults if it is ever touched.
            del cached_tokens

    if tokens is None:
        ids = array("I")
        for doc in iter_documents(paths, text_column=text_column, doc_cap_chars=None):
            ids.extend(tokenizer.encode(doc, prepend_bos=True))
            # Stop once enough tokens exist; the rest of the corpus is not needed.
            if len(ids) >= need_total:
                break
        # array("I") stores C unsigned ints; derive the numpy view from the
        # actual item size rather than assuming 4 bytes.
        id_dtype = np.dtype(f"=u{ids.itemsize}")
        encoded = np.frombuffer(ids, dtype=id_dtype).astype(dtype, copy=False)
        if len(encoded) < need_total and not allow_short_corpus:
            raise ValueError(
                f"corpus produced {len(encoded):,} tokens, but {need_total:,} are required"
            )
        output_dir.mkdir(parents=True, exist_ok=True)
        # Write to a sibling temp file, then rename it into place, so an
        # interrupted run cannot leave a partially written cache at
        # encoded_path. The ".tmp" suffix keeps it out of the
        # "tokens-v{vocab_size}-*.npy" pattern load_cached_splits globs for.
        # np.save is given a file object so it does not append ".npy".
        tmp_path = encoded_path.with_name(encoded_path.name + ".tmp")
        with open(tmp_path, "wb") as fh:
            np.save(fh, encoded)
        tmp_path.replace(encoded_path)
        tokens = np.load(encoded_path, mmap_mode="r")

    # Validation is the tail of the token stream, capped at a tenth of it.
    usable_val = min(val_tokens, max(0, len(tokens) // 10))
    if usable_val <= 1:
        raise ValueError("not enough tokens for validation")
    train = tokens[:-usable_val]
    val = tokens[-usable_val:]
    if len(train) < max_required_train_tokens and not allow_short_corpus:
        raise ValueError(
            f"train split has {len(train):,} tokens, but {max_required_train_tokens:,} are required"
        )
    return TokenSplits(
        train=train,
        val=val,
        tokenizer_path=tokenizer_path,
        encoded_path=encoded_path,
    )
|
|
|
|
def preview_documents(
    paths: list[Path], limit: int = 3, text_column: str = "text"
) -> list[str]:
    """Return the first ``limit`` documents yielded by ``iter_documents``.

    No character cap is applied, so plain-text files come back in chunks of
    1,000,000 characters and parquet values are returned whole.

    Args:
        paths: Corpus files to read, in order.
        limit: Maximum number of documents to return.
        text_column: Parquet column holding document text. It defaults to
            "text" as before, and is forwarded like in the other readers here.

    Returns:
        Up to ``limit`` document strings.
    """
    return list(islice(iter_documents(paths, text_column=text_column), limit))
|
|