| """ |
| One-time data preparation for autoresearch experiments. |
| Downloads data shards and trains a BPE tokenizer. |
| |
| Usage: |
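| python prepare.py # default prep (10 training shards + pinned val shard, then tokenizer) |
| python prepare.py --num-shards 8 # limit the download to 8 training shards (for testing) |
| python prepare.py --num-shards -1 # download every training shard |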
| |
| Data and tokenizer are stored in ~/.cache/autoresearch/. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import math |
| import argparse |
| import pickle |
| from multiprocessing import Pool |
|
|
| import requests |
| import pyarrow.parquet as pq |
| import rustbpe |
| import tiktoken |
| import torch |
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Evaluation constants
# ---------------------------------------------------------------------------

# Sequence length used by evaluate_bpb; can be overridden with HYDRA_SEQ_LEN.
MAX_SEQ_LEN = int(os.getenv("HYDRA_SEQ_LEN", "512"))
# Not referenced in this file; presumably a time budget in seconds that the
# training script reads (TODO confirm against the caller).
TIME_BUDGET = 300
# Number of tokens that sizes the validation pass in evaluate_bpb.
EVAL_TOKENS = 40 * 524288

# ---------------------------------------------------------------------------
# Cache layout and dataset location
# ---------------------------------------------------------------------------

CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
DATA_DIR = os.path.join(CACHE_DIR, "data")
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
# Index of the pinned validation shard; also the upper bound on how many
# training shards download_data will request.
MAX_SHARD = 6542
VAL_SHARD = MAX_SHARD
VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
# Total vocabulary size including special tokens; override with HYDRA_VOCAB_SIZE.
VOCAB_SIZE = int(os.getenv("HYDRA_VOCAB_SIZE", "65536"))

# Pre-tokenization split regex handed to the BPE trainer. It relies on
# possessive quantifiers and Unicode property classes, so it is not a pattern
# for the stdlib `re` module.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# Special tokens are appended after the learned merges; the first one doubles
# as the beginning-of-sequence marker.
SPECIAL_TOKENS = ["<|reserved_%d|>" % i for i in range(4)]
BOS_TOKEN = SPECIAL_TOKENS[0]
|
|
| |
| |
| |
|
|
def download_single_shard(index):
    """Download one parquet shard into DATA_DIR, retrying on failure.

    The shard is streamed to ``<name>.tmp`` and atomically moved into place
    once the body has been fully written, so the final path never holds a
    partial file.

    Args:
        index: Integer shard number; formatted as ``shard_{index:05d}.parquet``.

    Returns:
        True if the shard already existed or was downloaded successfully,
        False if all attempts failed.
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True

    url = f"{BASE_URL}/{filename}"
    temp_path = filepath + ".tmp"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            # The context manager closes the streamed response (and releases
            # its connection) even when the transfer fails part-way through.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                with open(temp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
            # os.replace overwrites atomically on every platform, whereas
            # os.rename fails on Windows if the destination already exists.
            os.replace(temp_path, filepath)
            print(f" Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Only the temp file can be partial; the final path is published
            # exclusively via the atomic replace above, so it is left alone.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            if attempt < max_attempts:
                # Exponential backoff: 2, 4, 8, 16 seconds.
                time.sleep(2 ** attempt)
    return False
|
|
|
|
def download_data(num_shards, download_workers=8):
    """Fetch the first ``num_shards`` training shards plus the pinned val shard.

    Args:
        num_shards: Number of training shards wanted; capped at MAX_SHARD.
        download_workers: Upper bound on parallel download processes.

    Prints a summary and returns None. Individual shard failures are counted
    in the summary but do not raise.
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    shard_ids = list(range(min(num_shards, MAX_SHARD)))
    if VAL_SHARD not in shard_ids:
        shard_ids.append(VAL_SHARD)
    total = len(shard_ids)

    # Count what is already on disk so a fully populated cache is a no-op.
    already_present = 0
    for shard_id in shard_ids:
        if os.path.exists(os.path.join(DATA_DIR, f"shard_{shard_id:05d}.parquet")):
            already_present += 1

    if already_present == total:
        print(f"Data: all {total} shards already downloaded at {DATA_DIR}")
        return

    missing = total - already_present
    print(f"Data: downloading {missing} shards ({already_present} already exist)...")

    # Every id is dispatched; download_single_shard returns immediately for
    # shards that already exist.
    pool_size = max(1, min(download_workers, missing))
    with Pool(processes=pool_size) as pool:
        outcomes = pool.map(download_single_shard, shard_ids)

    succeeded = sum(map(bool, outcomes))
    print(f"Data: {succeeded}/{total} shards ready at {DATA_DIR}")
|
|
| |
| |
| |
|
|
def list_parquet_files():
    """Return the sorted full paths of all ``*.parquet`` files in DATA_DIR.

    In-progress downloads are named ``*.parquet.tmp`` and therefore never match.
    Returns an empty list when DATA_DIR does not exist yet, so callers can
    report "no data" themselves instead of hitting FileNotFoundError.
    """
    try:
        names = os.listdir(DATA_DIR)
    except FileNotFoundError:
        return []
    return [os.path.join(DATA_DIR, name) for name in sorted(names) if name.endswith(".parquet")]
|
|
|
|
def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
    """Stream training documents for tokenizer training.

    Reads the "text" column of every shard except the pinned validation shard,
    truncating each document to ``doc_cap`` characters. Stops after the
    document that brings the running character count to ``max_chars`` or more
    (that document is still yielded).
    """
    chars_emitted = 0
    for shard_path in list_parquet_files():
        if shard_path.endswith(VAL_FILENAME):
            continue  # validation data must not influence the tokenizer
        shard = pq.ParquetFile(shard_path)
        for group_index in range(shard.num_row_groups):
            texts = shard.read_row_group(group_index).column("text").to_pylist()
            for text in texts:
                # Slicing is a no-op for documents already within the cap.
                doc = text[:doc_cap]
                chars_emitted += len(doc)
                yield doc
                if chars_emitted >= max_chars:
                    return
|
|
|
|
def train_tokenizer():
    """Train a BPE tokenizer with rustbpe and cache it under TOKENIZER_DIR.

    Produces two artifacts:
      * ``tokenizer.pkl``  - a pickled ``tiktoken.Encoding`` built from the
        learned merges plus SPECIAL_TOKENS appended at the end of the id space.
      * ``token_bytes.pt`` - an int32 tensor mapping token id -> number of raw
        bytes that token represents (0 for special tokens).

    Does nothing if both artifacts already exist. Exits the process with
    status 1 when fewer than two parquet shards are available.

    Raises:
        AssertionError: if the freshly trained tokenizer fails an
            encode/decode roundtrip. Nothing is written in that case.
    """
    tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")

    if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return

    os.makedirs(TOKENIZER_DIR, exist_ok=True)

    parquet_files = list_parquet_files()
    if len(parquet_files) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)

    # --- Train -------------------------------------------------------------
    print("Tokenizer: training BPE tokenizer...")
    t0 = time.time()

    tokenizer = rustbpe.Tokenizer()
    vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
    tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)

    # --- Wrap the learned merges in a tiktoken Encoding ---------------------
    pattern = tokenizer.get_pattern()
    mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
    tokens_offset = len(mergeable_ranks)
    special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )

    # --- Sanity check BEFORE persisting -------------------------------------
    # Checking first means a broken tokenizer is never cached and then
    # reported as "already trained" on the next run. An explicit raise keeps
    # the check active under `python -O`.
    test = "Hello world! Numbers: 123. Unicode: 你好"
    decoded = enc.decode(enc.encode_ordinary(test))
    if decoded != test:
        raise AssertionError(f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}")

    # --- Save the encoding (temp file + atomic replace) ---------------------
    # An interrupted run must not leave a truncated file that passes the
    # "already trained" existence check above.
    tokenizer_pkl_tmp = tokenizer_pkl + ".tmp"
    with open(tokenizer_pkl_tmp, "wb") as f:
        pickle.dump(enc, f)
    os.replace(tokenizer_pkl_tmp, tokenizer_pkl)

    t1 = time.time()
    print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")

    # --- Per-token byte lengths ---------------------------------------------
    # Lengths come from the raw token bytes. Going through decode() and
    # re-encoding would turn every token that is not valid UTF-8 on its own
    # (e.g. one byte of a multi-byte character) into U+FFFD and count it as
    # 3 bytes, inflating the byte total used by the bits-per-byte metric.
    print("Tokenizer: building token_bytes lookup...")
    special_ids = set(special_tokens.values())
    token_bytes_list = [
        0 if token_id in special_ids else len(enc.decode_single_token_bytes(token_id))
        for token_id in range(enc.n_vocab)
    ]
    token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
    token_bytes_tmp = token_bytes_path + ".tmp"
    torch.save(token_bytes_tensor, token_bytes_tmp)
    os.replace(token_bytes_tmp, token_bytes_path)
    print(f"Tokenizer: saved token_bytes to {token_bytes_path}")

    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")
|
|
| |
| |
| |
|
|
class Tokenizer:
    """Thin convenience wrapper around a trained ``tiktoken`` encoding.

    Training happens in train_tokenizer(); this class only loads the pickled
    encoding and exposes encode/decode helpers plus the BOS token id.
    """

    def __init__(self, enc):
        self.enc = enc
        # Resolve the BOS id once; it is needed for every encoded document.
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        """Load ``tokenizer.pkl`` from ``tokenizer_dir`` and wrap it."""
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        # NOTE: pickle executes arbitrary code on load; only point this at a
        # cache directory you trust.
        with open(pickle_path, "rb") as handle:
            encoding = pickle.load(handle)
        return cls(encoding)

    def get_vocab_size(self):
        """Return the total vocabulary size, special tokens included."""
        return self.enc.n_vocab

    def get_bos_token_id(self):
        """Return the integer id of the beginning-of-sequence token."""
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        """Encode a string or a list of strings into token ids.

        Args:
            text: A single ``str`` (returns one id list) or a ``list`` of
                strings (returns a list of id lists).
            prepend: Optional token to put in front of every sequence, given
                either as an integer id or as a special-token string.
            num_threads: Worker threads used for the batched path.

        Raises:
            ValueError: if ``text`` is neither a ``str`` nor a ``list``.
        """
        prepend_id = None
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)

        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            sequences = [ids]
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            sequences = ids
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

        if prepend is not None:
            for sequence in sequences:
                sequence.insert(0, prepend_id)
        return ids

    def decode(self, ids):
        """Decode a list of token ids back into a string."""
        return self.enc.decode(ids)
|
|
|
|
# Loaded token-length tensors, keyed by the string form of the target device.
_TOKEN_BYTES_CACHE: dict = {}


def get_token_bytes(device="cpu"):
    """Return the token-id -> byte-length tensor, loaded onto ``device``.

    The tensor is read from ``token_bytes.pt`` in TOKENIZER_DIR the first time
    a given device is requested and served from an in-process cache afterwards.
    """
    cache_key = str(device)
    try:
        return _TOKEN_BYTES_CACHE[cache_key]
    except KeyError:
        pass

    token_bytes_file = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    with open(token_bytes_file, "rb") as handle:
        loaded = torch.load(handle, map_location=device)
    _TOKEN_BYTES_CACHE[cache_key] = loaded
    return loaded
|
|
|
|
def _document_batches(split, tokenizer_batch_size=128):
    """Endlessly yield ``(list_of_texts, epoch)`` pairs read from parquet shards.

    For ``split == "train"`` every shard except the pinned validation shard is
    used; any other value reads only the validation shard. ``epoch`` starts at
    1 and increases by one after each complete pass over the selected shards.
    """
    shard_paths = list_parquet_files()
    assert len(shard_paths) > 0, "No parquet files found. Run prepare.py first."

    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split != "train":
        shard_paths = [val_path]
    else:
        shard_paths = [path for path in shard_paths if path != val_path]
        assert len(shard_paths) > 0, "No training shards found."

    epoch = 1
    while True:
        for shard_path in shard_paths:
            shard = pq.ParquetFile(shard_path)
            for group_index in range(shard.num_row_groups):
                docs = shard.read_row_group(group_index).column('text').to_pylist()
                # Hand documents out in tokenizer-sized slices.
                for start in range(0, len(docs), tokenizer_batch_size):
                    stop = start + tokenizer_batch_size
                    yield docs[start:stop], epoch
        epoch += 1
|
|
|
|
def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
    """
    BOS-aligned dataloader with best-fit packing.

    Every row starts with BOS. Documents are packed using best-fit to minimize
    cropping: the longest buffered document that still fits the remaining space
    is placed next. When no document fits, the shortest buffered document is
    cropped to fill the row exactly (its remainder is discarded).
    100% utilization (no padding).

    Args:
        tokenizer: Object providing ``get_bos_token_id()`` and
            ``encode(list_of_str, prepend=...)``.
        B: Rows per batch.
        T: Tokens per row in the yielded inputs/targets (rows are packed to
            T + 1 tokens and then shifted by one).
        split: "train" or "val".
        buffer_size: Minimum number of tokenized documents kept on hand for the
            best-fit search; values below 1 are treated as 1.

    Yields:
        ``(inputs, targets, epoch)`` forever. ``inputs`` and ``targets`` are
        (B, T) int64 views into ONE device buffer that is overwritten on the
        next iteration, so consume (or clone) them before requesting the next
        batch. Tensors live on CUDA when it is available, otherwise on CPU.
    """
    assert split in ["train", "val"]
    row_capacity = T + 1
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    epoch = 1
    # The packing loop needs at least one buffered document to choose from.
    min_buffered = max(1, buffer_size)

    def refill_buffer():
        nonlocal epoch
        doc_batch, epoch = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
        doc_buffer.extend(token_lists)

    # Staging buffers are allocated once and reused for every batch.
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    _dev = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = _dev == "cuda"
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=_dev)
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    # Marks completion of the most recent async host-to-device copy.
    copy_done = torch.cuda.Event() if use_cuda else None
    copy_pending = False

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < min_buffered:
                    refill_buffer()

                remaining = row_capacity - pos

                # Best fit: the longest buffered document that still fits.
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # Nothing fits: crop the shortest document to fill the row.
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        # A non_blocking copy returns before the transfer has run. Wait for the
        # previous one to finish before overwriting its pinned source, or that
        # batch could be corrupted with this batch's tokens.
        if copy_pending:
            copy_done.synchronize()

        # Shift by one token: inputs are tokens [0, T), targets are [1, T].
        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        if use_cuda:
            gpu_buffer.copy_(cpu_buffer, non_blocking=True)
            copy_done.record()
            copy_pending = True
        else:
            gpu_buffer.copy_(cpu_buffer)
        yield inputs, targets, epoch
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB): vocab size-independent evaluation metric.

    Sums per-token cross-entropy (in nats), sums target byte lengths,
    then converts nats/byte to bits/byte. Special tokens (byte length 0)
    are excluded from both sums.
    Uses fixed MAX_SEQ_LEN so results are comparable across configs.

    Args:
        model: Callable as ``model(x, y, reduction='none')`` returning one
            loss value per target token; its first parameter decides the
            device used for accumulation.
        tokenizer: Tokenizer handed to the validation dataloader.
        batch_size: Rows per evaluation batch.

    Returns:
        Validation bits per byte as a Python float.

    Perf: accumulates on the device and synchronizes once at the end.
    """
    _dev = next(model.parameters()).device
    token_bytes = get_token_bytes(device=_dev)
    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)

    # Device-side accumulators avoid a host sync on every step.
    total_nats_t = torch.zeros(1, device=_dev, dtype=torch.float64)
    total_bytes_t = torch.zeros(1, device=_dev, dtype=torch.int64)

    for _ in range(steps):
        # The loader yields views into one reused buffer, so the next batch
        # must not be requested until this one has been passed to the model.
        # Fetching it earlier replaces x/y with the following batch, so the
        # first validation batch is never scored.
        x, y, _epoch = next(val_loader)
        loss_flat = model(x, y, reduction='none').view(-1)
        y_flat = y.view(-1)
        nbytes = token_bytes[y_flat]
        # Zero-length (special) tokens contribute neither loss nor bytes.
        mask = nbytes > 0
        total_nats_t += (loss_flat * mask).sum()
        total_bytes_t += nbytes.sum()

    # Single host sync for the whole evaluation.
    total_nats = total_nats_t.item()
    total_bytes = total_bytes_t.item()
    return total_nats / (math.log(2) * total_bytes)
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Command-line entry point: download shards, then train the tokenizer.
    cli = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
    cli.add_argument(
        "--num-shards",
        type=int,
        default=10,
        help="Number of training shards to download (-1 = all). Val shard is always pinned.",
    )
    cli.add_argument(
        "--download-workers",
        type=int,
        default=8,
        help="Number of parallel download workers",
    )
    options = cli.parse_args()

    # -1 is the sentinel for "every available training shard".
    shard_count = options.num_shards
    if shard_count == -1:
        shard_count = MAX_SHARD

    print(f"Cache directory: {CACHE_DIR}\n")

    # Step 1: fetch the data shards.
    download_data(shard_count, download_workers=options.download_workers)
    print()

    # Step 2: train (or reuse) the tokenizer.
    train_tokenizer()
    print("\nDone! Ready to train.")
|
|