| """Real data loading for WrinkleBrane training. |
| |
| Byte-level tokenization — each byte is a token (vocab_size=259): |
| 0 = PAD |
| 1 = BOS |
| 2 = EOS |
| 3..258 = byte 0x00..0xFF |
| |
| Data files are loaded, concatenated (with EOS between documents), and served |
| as random-offset chunks for next-token prediction. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import random |
| from pathlib import Path |
| from typing import List, Tuple, Optional |
|
|
| import torch |
| from torch import Tensor |
|
|
|
|
| |
# Special token IDs for the byte-level vocabulary (see module docstring).
PAD_ID = 0  # padding token
BOS_ID = 1  # beginning-of-document marker
EOS_ID = 2  # end-of-document marker
BYTE_OFFSET = 3  # byte value b is encoded as token (b + BYTE_OFFSET)
VOCAB_SIZE = 259  # 3 special tokens + 256 possible byte values
|
|
|
|
def encode_bytes(text: str) -> List[int]:
    """Encode a string to byte-level token IDs.

    The string is encoded as UTF-8 and each byte is shifted by
    ``BYTE_OFFSET`` to make room for the special tokens.
    """
    raw_bytes = text.encode("utf-8", errors="replace")
    return [BYTE_OFFSET + byte for byte in raw_bytes]
|
|
|
|
def decode_tokens(ids: List[int]) -> str:
    """Decode token IDs back to a string.

    Special tokens (anything below ``BYTE_OFFSET``) are dropped; the
    remaining IDs are shifted back to raw bytes and decoded as UTF-8.
    """
    payload = bytes(tok - BYTE_OFFSET for tok in ids if tok >= BYTE_OFFSET)
    return payload.decode("utf-8", errors="replace")
|
|
|
|
class ByteCorpus:
    """Holds a tokenised corpus in a flat int64 tensor.

    Each document is wrapped with BOS/EOS markers.
    Random chunks can be drawn for training.
    """

    def __init__(self, token_ids: Tensor):
        """
        Parameters
        ----------
        token_ids : Tensor [N]
            Flat 1-D tensor of all token IDs.
        """
        # data: flat 1-D tensor of token IDs; length cached for sampling.
        self.data = token_ids
        self.length = len(token_ids)

    @classmethod
    def from_files(cls, paths: List[str], shuffle_docs: bool = True) -> "ByteCorpus":
        """Load and tokenise multiple text files.

        Documents within each file are split on ``<|endoftext|>`` markers
        if present, otherwise the whole file is one document.

        Parameters
        ----------
        paths : List[str]
            Text files to read (UTF-8; undecodable bytes are replaced).
        shuffle_docs : bool
            Shuffle document order before concatenation.
        """
        documents: List[str] = []
        for path in paths:
            text = Path(path).read_text(encoding="utf-8", errors="replace")
            if "<|endoftext|>" in text:
                # Explicit document markers: keep each non-empty piece.
                for part in text.split("<|endoftext|>"):
                    part = part.strip()
                    if part:
                        documents.append(part)
            else:
                # No markers: the whole file is a single document.
                documents.append(text.strip())

        if shuffle_docs:
            random.shuffle(documents)

        # Flatten into one ID stream, wrapping every document in BOS/EOS.
        all_ids: List[int] = []
        for doc in documents:
            all_ids.append(BOS_ID)
            all_ids.extend(encode_bytes(doc))
            all_ids.append(EOS_ID)

        token_ids = torch.tensor(all_ids, dtype=torch.long)
        # element_size() gives the true per-token byte width (8 for int64);
        # the previous hard-coded *4 assumed int32 and under-reported by 2x.
        size_mb = token_ids.numel() * token_ids.element_size() / 1e6
        print(f"  Corpus: {len(documents)} documents, "
              f"{len(token_ids):,} tokens ({size_mb:.1f}MB)")

        return cls(token_ids)

    def get_batch(self, batch_size: int, seq_len: int) -> Tuple[Tensor, Tensor]:
        """Sample random chunks for next-token prediction.

        Parameters
        ----------
        batch_size : int
            Number of chunks to draw.
        seq_len : int
            Length of each chunk.

        Returns
        -------
        input_ids : Tensor [B, seq_len]
        target_ids : Tensor [B, seq_len]
            Shifted by one position.

        Raises
        ------
        ValueError
            If the corpus is too short to supply ``seq_len`` tokens plus
            their one-step-shifted targets.
        """
        max_start = self.length - seq_len - 1
        if max_start <= 0:
            # torch.randint would otherwise fail with an opaque RuntimeError.
            raise ValueError(
                f"Corpus has {self.length} tokens; need more than "
                f"{seq_len + 1} to sample chunks of seq_len={seq_len}."
            )
        starts = torch.randint(0, max_start, (batch_size,))

        input_ids = torch.stack([self.data[s:s + seq_len] for s in starts])
        target_ids = torch.stack([self.data[s + 1:s + 1 + seq_len] for s in starts])

        return input_ids, target_ids
|
|
|
|
def load_train_val(
    data_dir: str,
    shuffle: bool = True,
) -> Tuple[ByteCorpus, ByteCorpus]:
    """Load train and validation corpora from the raw data directory.

    Training data: all files except tinystories_valid.txt
    Validation data: tinystories_valid.txt

    Parameters
    ----------
    data_dir : str
        Path to the raw data directory.
    shuffle : bool
        Whether to shuffle documents within train.
    """
    root = Path(data_dir)

    train_names = [
        "tinystories_train.txt",
        "math_data.txt",
        "logic_reasoning.txt",
        "code_snippets.txt",
        "ascii_tables.txt",
        "byte_tables.txt",
    ]
    train_files = [str(root / name) for name in train_names]
    val_files = [str(root / "tinystories_valid.txt")]

    # Silently drop any files that are absent on disk.
    train_files = [path for path in train_files if os.path.exists(path)]
    val_files = [path for path in val_files if os.path.exists(path)]

    print("Loading training data...")
    train_corpus = ByteCorpus.from_files(train_files, shuffle_docs=shuffle)
    print("Loading validation data...")
    val_corpus = ByteCorpus.from_files(val_files, shuffle_docs=False)

    return train_corpus, val_corpus
|
|