File size: 4,688 Bytes
"""Real data loading for WrinkleBrane training.

Byte-level tokenization — each byte is a token (vocab_size=259):

    0      = PAD
    1      = BOS
    2      = EOS
    3..258 = byte 0x00..0xFF

Data files are loaded, split into documents, framed with BOS/EOS markers,
concatenated into a single token stream, and served as random-offset chunks
for next-token prediction.
"""
from __future__ import annotations
import os
import random
from pathlib import Path
from typing import List, Tuple, Optional
import torch
from torch import Tensor
# Reserved control tokens occupy the lowest IDs.
PAD_ID = 0  # padding
BOS_ID = 1  # start-of-document marker
EOS_ID = 2  # end-of-document marker

# A raw byte value b is stored as token ID b + BYTE_OFFSET.
BYTE_OFFSET = 3
VOCAB_SIZE = BYTE_OFFSET + 256  # 3 special + 256 bytes = 259
def encode_bytes(text: str) -> List[int]:
    """Turn *text* into byte-level token IDs.

    The string is UTF-8 encoded (characters that cannot be encoded are
    replaced) and every resulting byte value is shifted up by
    ``BYTE_OFFSET`` so it lands above the special-token range.
    """
    utf8 = text.encode("utf-8", errors="replace")
    return [BYTE_OFFSET + value for value in utf8]
def decode_tokens(ids: List[int]) -> str:
    """Turn byte-level token IDs back into text.

    IDs below ``BYTE_OFFSET`` (the special tokens) are dropped; every other
    ID is shifted back down to its byte value. The bytes are then decoded
    as UTF-8, with malformed sequences replaced rather than raising.
    """
    payload = bytes(tok - BYTE_OFFSET for tok in ids if tok >= BYTE_OFFSET)
    return payload.decode("utf-8", errors="replace")
class ByteCorpus:
    """Hold a tokenised corpus as one flat 1-D tensor of token IDs.

    ``from_files`` frames every document as ``BOS, <bytes...>, EOS`` and
    stores the IDs as ``torch.long`` (int64). ``get_batch`` draws random
    fixed-length windows from that tensor for next-token prediction.
    """

    def __init__(self, token_ids: Tensor):
        """
        Parameters
        ----------
        token_ids : Tensor [N]
            Flat 1-D tensor of all token IDs.
        """
        self.data = token_ids
        self.length = len(token_ids)

    @classmethod
    def from_files(cls, paths: List[str], shuffle_docs: bool = True) -> "ByteCorpus":
        """Load and tokenise multiple text files.

        Each file is split on ``<|endoftext|>`` markers; a file without the
        marker is treated as a single document. Segments are stripped of
        surrounding whitespace and empty segments are dropped, so an empty
        or whitespace-only file contributes no document.

        Parameters
        ----------
        paths : list of str
            Text files to read (UTF-8, undecodable bytes replaced).
        shuffle_docs : bool
            Shuffle the collected documents with the global ``random``
            state before they are concatenated.

        Returns
        -------
        ByteCorpus
            Corpus holding all documents, each framed with BOS/EOS.
        """
        documents = []
        for path in paths:
            text = Path(path).read_text(encoding="utf-8", errors="replace")
            # ``str.split`` returns ``[text]`` when the marker is absent, so
            # this one loop covers both the multi- and single-document case.
            for part in text.split("<|endoftext|>"):
                part = part.strip()
                if part:
                    documents.append(part)

        if shuffle_docs:
            random.shuffle(documents)

        # Encode all documents with BOS/EOS framing.
        all_ids = []
        for doc in documents:
            all_ids.append(BOS_ID)
            all_ids.extend(encode_bytes(doc))
            all_ids.append(EOS_ID)
        token_ids = torch.tensor(all_ids, dtype=torch.long)

        # Report the real storage size: element_size() is 8 for torch.long.
        size_mb = token_ids.numel() * token_ids.element_size() / 1e6
        print(f" Corpus: {len(documents)} documents, "
              f"{len(token_ids):,} tokens ({size_mb:.1f}MB)")
        return cls(token_ids)

    def get_batch(self, batch_size: int, seq_len: int) -> Tuple[Tensor, Tensor]:
        """Sample random chunks for next-token prediction.

        Parameters
        ----------
        batch_size : int
            Number of windows to draw (with replacement).
        seq_len : int
            Length of each window.

        Returns
        -------
        input_ids : Tensor [B, seq_len]
        target_ids : Tensor [B, seq_len]
            ``input_ids`` shifted by one position.

        Raises
        ------
        ValueError
            If the corpus holds fewer than ``seq_len + 1`` tokens.
        """
        # The target window ends at start + seq_len + 1, so the largest
        # admissible start index is length - seq_len - 1 (inclusive).
        last_start = self.length - seq_len - 1
        if last_start < 0:
            raise ValueError(
                f"corpus has {self.length} tokens; need at least "
                f"{seq_len + 1} for seq_len={seq_len}"
            )
        # randint's upper bound is exclusive, hence the +1.
        starts = torch.randint(0, last_start + 1, (batch_size,)).tolist()
        input_ids = torch.stack([self.data[s:s + seq_len] for s in starts])
        target_ids = torch.stack([self.data[s + 1:s + 1 + seq_len] for s in starts])
        return input_ids, target_ids
def load_train_val(
    data_dir: str,
    shuffle: bool = True,
) -> Tuple[ByteCorpus, ByteCorpus]:
    """Build the training and validation corpora from a raw data directory.

    Training data comes from a fixed list of file names (TinyStories train,
    math, logic, code, ASCII tables, byte tables); validation data comes
    from ``tinystories_valid.txt``. Files that do not exist are silently
    left out.

    Parameters
    ----------
    data_dir : str
        Path to the raw data directory.
    shuffle : bool
        Whether to shuffle documents within train. Validation documents
        are never shuffled.

    Returns
    -------
    (train_corpus, val_corpus) : tuple of ByteCorpus
    """
    root = Path(data_dir)
    train_names = (
        "tinystories_train.txt",
        "math_data.txt",
        "logic_reasoning.txt",
        "code_snippets.txt",
        "ascii_tables.txt",
        "byte_tables.txt",
    )
    val_names = ("tinystories_valid.txt",)

    def _existing(names):
        # Keep only the candidate paths that are present on disk.
        candidates = [str(root / name) for name in names]
        return [candidate for candidate in candidates if os.path.exists(candidate)]

    train_paths = _existing(train_names)
    val_paths = _existing(val_names)

    print("Loading training data...")
    train_corpus = ByteCorpus.from_files(train_paths, shuffle_docs=shuffle)
    print("Loading validation data...")
    val_corpus = ByteCorpus.from_files(val_paths, shuffle_docs=False)
    return train_corpus, val_corpus
|