"""
One-time data preparation for autoresearch experiments.
Downloads data shards and trains a BPE tokenizer.
Usage:
python prepare.py # full prep (download + tokenizer)
python prepare.py --num-shards 8 # download only 8 shards (for testing)
Data and tokenizer are stored in ~/.cache/autoresearch/.
"""
import os
import sys
import time
import math
import argparse
import pickle
from multiprocessing import Pool
import requests
import pyarrow.parquet as pq
import rustbpe
import tiktoken
import torch
# ---------------------------------------------------------------------------
# Constants (fixed, do not modify)
# ---------------------------------------------------------------------------
# NOTE: MAX_SEQ_LEN is "fixed" only by default; the HYDRA_SEQ_LEN environment
# variable overrides it at import time. evaluate_bpb() uses it as the
# evaluation sequence length.
MAX_SEQ_LEN = int(os.environ.get("HYDRA_SEQ_LEN", "512")) # context length
# Not referenced in this file; presumably consumed by the training script.
TIME_BUDGET = 300 # training time budget in seconds (5 minutes)
# evaluate_bpb() runs EVAL_TOKENS // (batch_size * MAX_SEQ_LEN) steps.
EVAL_TOKENS = 40 * 524288 # number of tokens for val eval
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Everything this module writes lives under ~/.cache/autoresearch/.
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
# Downloaded parquet shards.
DATA_DIR = os.path.join(CACHE_DIR, "data")
# tokenizer.pkl and token_bytes.pt produced by train_tokenizer().
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
# Shards are fetched as f"{BASE_URL}/shard_{index:05d}.parquet".
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
# The last shard is held out: it is always downloaded and never used for
# tokenizer training or the "train" split.
VAL_SHARD = MAX_SHARD # pinned validation shard (shard_06542)
VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
# Total vocabulary size including the special tokens; overridable through the
# HYDRA_VOCAB_SIZE environment variable.
VOCAB_SIZE = int(os.environ.get("HYDRA_VOCAB_SIZE", "65536")) # 64k — production-grade (was 8k experimental)
# BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# Four reserved special tokens; train_tokenizer() assigns them the ids that
# follow the learned merge ranks.
SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
# The first reserved token doubles as the beginning-of-sequence marker.
BOS_TOKEN = "<|reserved_0|>"
# ---------------------------------------------------------------------------
# Data download
# ---------------------------------------------------------------------------
def download_single_shard(index):
    """Download one parquet shard with retries. Returns True on success.

    The shard is streamed to ``<name>.tmp`` and moved into place only after
    the whole body has been written, so a file that exists under its final
    name is always complete.

    Args:
        index: Shard number; the file is named ``shard_{index:05d}.parquet``.

    Returns:
        True if the shard already existed or was downloaded, False once all
        attempts have failed.
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True

    url = f"{BASE_URL}/{filename}"
    temp_path = filepath + ".tmp"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            # With stream=True the connection stays checked out until the
            # response is closed; the context manager guarantees that happens
            # even when raise_for_status() or a write fails mid-download.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                with open(temp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
            # os.replace overwrites an existing destination on every platform,
            # whereas os.rename raises on Windows if the target exists.
            os.replace(temp_path, filepath)
            print(f" Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Best-effort cleanup of partial output before the next attempt.
            for path in [temp_path, filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            if attempt < max_attempts:
                # Exponential backoff: 2, 4, 8, 16 seconds.
                time.sleep(2 ** attempt)
    return False
def download_data(num_shards, download_workers=8):
    """Fetch the requested training shards plus the pinned validation shard.

    Args:
        num_shards: Number of training shards wanted; capped at MAX_SHARD.
        download_workers: Upper bound on parallel download processes.

    Prints a summary and returns None. Shards that fail to download are only
    reflected in the printed ready count.
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    shard_ids = list(range(min(num_shards, MAX_SHARD)))
    if VAL_SHARD not in shard_ids:
        shard_ids.append(VAL_SHARD)
    total = len(shard_ids)

    # Work out how many shards are already on disk.
    present = 0
    for shard_id in shard_ids:
        if os.path.exists(os.path.join(DATA_DIR, f"shard_{shard_id:05d}.parquet")):
            present += 1

    if present == total:
        print(f"Data: all {total} shards already downloaded at {DATA_DIR}")
        return

    missing = total - present
    print(f"Data: downloading {missing} shards ({present} already exist)...")

    # Never start more processes than there are shards to fetch, and at least one.
    pool_size = max(1, min(download_workers, missing))
    with Pool(processes=pool_size) as pool:
        # Every id is mapped; shards already on disk return immediately.
        outcomes = pool.map(download_single_shard, shard_ids)

    succeeded = len([outcome for outcome in outcomes if outcome])
    print(f"Data: {succeeded}/{total} shards ready at {DATA_DIR}")
# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------
def list_parquet_files():
    """Return the parquet shard paths found in DATA_DIR, sorted by file name."""
    shard_paths = []
    for entry in sorted(os.listdir(DATA_DIR)):
        # Keep only finished shards; in-flight downloads carry a ".tmp" suffix.
        if not entry.endswith(".parquet") or entry.endswith(".tmp"):
            continue
        shard_paths.append(os.path.join(DATA_DIR, entry))
    return shard_paths
def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
    """Yield training documents (every shard except the pinned validation shard).

    Args:
        max_chars: Stop after the yielded documents total at least this many
            characters; the document that crosses the limit is still yielded.
        doc_cap: Documents longer than this are truncated to ``doc_cap``
            characters before being yielded and counted.
    """
    train_paths = [path for path in list_parquet_files() if not path.endswith(VAL_FILENAME)]
    chars_seen = 0
    for path in train_paths:
        shard = pq.ParquetFile(path)
        for group_index in range(shard.num_row_groups):
            texts = shard.read_row_group(group_index).column("text").to_pylist()
            for text in texts:
                doc = text
                if len(doc) > doc_cap:
                    doc = doc[:doc_cap]
                chars_seen += len(doc)
                yield doc
                # The budget is checked after yielding, so the last document
                # may push the total past max_chars.
                if chars_seen >= max_chars:
                    return
def train_tokenizer():
    """Train a BPE tokenizer with rustbpe and cache it as a tiktoken pickle.

    Writes two artifacts into TOKENIZER_DIR:
      * ``tokenizer.pkl``   - pickled ``tiktoken.Encoding``
      * ``token_bytes.pt``  - int32 tensor with one byte length per token id
        (0 for special tokens), used by the bits-per-byte evaluation

    Returns immediately if both artifacts already exist. Exits the process
    with status 1 when fewer than two parquet shards are available.
    """
    tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    if all(os.path.exists(artifact) for artifact in (tokenizer_pkl, token_bytes_path)):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return

    os.makedirs(TOKENIZER_DIR, exist_ok=True)

    if len(list_parquet_files()) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)

    # --- Train with rustbpe ---
    print("Tokenizer: training BPE tokenizer...")
    started = time.time()
    trainer = rustbpe.Tokenizer()
    # The special tokens are added afterwards, so reserve room for them.
    trainer.train_from_iterator(
        text_iterator(),
        VOCAB_SIZE - len(SPECIAL_TOKENS),
        pattern=SPLIT_PATTERN,
    )

    # Rebuild the trained merges as a tiktoken encoding.
    pattern = trainer.get_pattern()
    merge_ranks = {bytes(piece): rank for piece, rank in trainer.get_mergeable_ranks()}
    first_special_id = len(merge_ranks)
    special_ids = {}
    for offset, name in enumerate(SPECIAL_TOKENS):
        special_ids[name] = first_special_id + offset
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=merge_ranks,
        special_tokens=special_ids,
    )

    # Persist the encoding.
    with open(tokenizer_pkl, "wb") as f:
        pickle.dump(enc, f)
    elapsed = time.time() - started
    print(f"Tokenizer: trained in {elapsed:.1f}s, saved to {tokenizer_pkl}")

    # --- Build token_bytes lookup for BPB evaluation ---
    print("Tokenizer: building token_bytes lookup...")
    reserved = set(SPECIAL_TOKENS)
    # NOTE(review): lengths are measured on the *decoded string*. If enc.decode
    # substitutes replacement characters for tokens that are not valid UTF-8 on
    # their own, those tokens get a length that differs from their raw byte
    # count - confirm. Left as-is because the evaluation metric is pinned.
    decoded_tokens = (enc.decode([token_id]) for token_id in range(enc.n_vocab))
    byte_lengths = [
        0 if piece in reserved else len(piece.encode("utf-8"))
        for piece in decoded_tokens
    ]
    torch.save(torch.tensor(byte_lengths, dtype=torch.int32), token_bytes_path)
    print(f"Tokenizer: saved token_bytes to {token_bytes_path}")

    # Round-trip sanity check.
    test = "Hello world! Numbers: 123. Unicode: 你好"
    decoded = enc.decode(enc.encode_ordinary(test))
    assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")
# ---------------------------------------------------------------------------
# Runtime utilities (imported by train.py)
# ---------------------------------------------------------------------------
class Tokenizer:
    """Thin runtime wrapper around a trained encoding object.

    Training lives in train_tokenizer(); this class only loads the pickled
    encoding and exposes encode/decode helpers.
    """

    def __init__(self, enc):
        self.enc = enc
        # Resolve the BOS id once so callers do not repeat the lookup.
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        """Load ``tokenizer.pkl`` from ``tokenizer_dir`` and wrap it."""
        pkl_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        # NOTE: unpickling can execute arbitrary code; only point this at a
        # tokenizer directory you trust.
        with open(pkl_path, "rb") as f:
            loaded = pickle.load(f)
        return cls(loaded)

    def get_vocab_size(self):
        """Return the size of the vocabulary as reported by the encoding."""
        return self.enc.n_vocab

    def get_bos_token_id(self):
        """Return the id of the beginning-of-sequence token."""
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        """Encode a string, or a list of strings, into token ids.

        Args:
            text: A ``str`` (returns one id list) or a ``list`` of strings
                (returns one id list per string).
            prepend: Optional token placed at the front of every id list;
                either a token id (``int``) or a single-token string.
            num_threads: Thread count used for batch encoding.

        Raises:
            ValueError: If ``text`` is neither a ``str`` nor a ``list``.
        """
        has_prefix = prepend is not None
        if has_prefix:
            if isinstance(prepend, int):
                prefix_id = prepend
            else:
                prefix_id = self.enc.encode_single_token(prepend)

        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if has_prefix:
                ids.insert(0, prefix_id)
            return ids

        if isinstance(text, list):
            rows = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if has_prefix:
                for row in rows:
                    row.insert(0, prefix_id)
            return rows

        raise ValueError(f"Invalid input type: {type(text)}")

    def decode(self, ids):
        """Decode a sequence of token ids back into text."""
        return self.enc.decode(ids)
# Loaded token-length tensors, keyed by ``str(device)``.
_TOKEN_BYTES_CACHE: dict = {}


def get_token_bytes(device="cpu"):
    """Return the token byte-length tensor for ``device``.

    The tensor is read from ``token_bytes.pt`` in TOKENIZER_DIR on first use
    for a given device and served from an in-process cache afterwards.
    """
    cache_key = str(device)
    if cache_key in _TOKEN_BYTES_CACHE:
        return _TOKEN_BYTES_CACHE[cache_key]

    with open(os.path.join(TOKENIZER_DIR, "token_bytes.pt"), "rb") as f:
        loaded = torch.load(f, map_location=device)
    _TOKEN_BYTES_CACHE[cache_key] = loaded
    return loaded
def _document_batches(split, tokenizer_batch_size=128):
    """Endlessly yield ``(documents, epoch)`` pairs read from the parquet shards.

    Args:
        split: ``"train"`` reads every shard except the pinned validation
            shard; any other value reads only the validation shard.
        tokenizer_batch_size: Maximum number of documents per yielded batch.

    ``epoch`` starts at 1 and increases by one after each full pass over the
    selected shards.
    """
    all_shards = list_parquet_files()
    assert len(all_shards) > 0, "No parquet files found. Run prepare.py first."

    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split == "train":
        shards = [path for path in all_shards if path != val_path]
        assert len(shards) > 0, "No training shards found."
    else:
        shards = [val_path]

    epoch = 1
    while True:
        for shard_path in shards:
            shard = pq.ParquetFile(shard_path)
            for group_index in range(shard.num_row_groups):
                docs = shard.read_row_group(group_index).column('text').to_pylist()
                # Slice each row group into tokenizer-sized chunks.
                for start in range(0, len(docs), tokenizer_batch_size):
                    yield docs[start:start + tokenizer_batch_size], epoch
        epoch += 1
def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
    """
    BOS-aligned dataloader with best-fit packing.

    Each of the B rows holds T + 1 tokens. A row is filled by repeatedly taking
    the longest buffered document that still fits; when nothing fits, the
    shortest buffered document is cropped to fill the row exactly, so rows
    contain no padding. Every document is encoded with BOS prepended.

    Yields ``(inputs, targets, epoch)`` forever, where ``inputs`` is
    ``rows[:, :-1]`` and ``targets`` is ``rows[:, 1:]``.

    NOTE: ``inputs`` and ``targets`` are views into one buffer that is
    overwritten on every iteration; consume (or clone) them before asking for
    the next batch.
    """
    assert split in ["train", "val"]
    row_capacity = T + 1
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    epoch = 1

    def refill_buffer():
        # Pull one more document batch, tokenize it and queue the token lists.
        nonlocal epoch
        doc_batch, epoch = next(batches)
        doc_buffer.extend(tokenizer.encode(doc_batch, prepend=bos_token))

    # Staging layout: [inputs (B*T) | targets (B*T)] in one flat buffer.
    flat = B * T
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    _dev = "cuda" if torch.cuda.is_available() else "cpu"
    on_cuda = _dev == "cuda"
    cpu_buffer = torch.empty(2 * flat, dtype=torch.long, pin_memory=on_cuda)
    gpu_buffer = torch.empty(2 * flat, dtype=torch.long, device=_dev)
    cpu_inputs, cpu_targets = cpu_buffer[:flat].view(B, T), cpu_buffer[flat:].view(B, T)
    inputs, targets = gpu_buffer[:flat].view(B, T), gpu_buffer[flat:].view(B, T)

    while True:
        for row in range(B):
            filled = 0
            while filled < row_capacity:
                # Keep the candidate pool topped up before each placement.
                while len(doc_buffer) < buffer_size:
                    refill_buffer()

                room = row_capacity - filled

                # Longest document that fits entirely (first one wins ties).
                pick, pick_len = -1, 0
                for idx, candidate in enumerate(doc_buffer):
                    candidate_len = len(candidate)
                    if pick_len < candidate_len <= room:
                        pick, pick_len = idx, candidate_len

                if pick < 0:
                    # Nothing fits: crop the shortest document to the space left.
                    pick = min(range(len(doc_buffer)), key=lambda j: len(doc_buffer[j]))
                    tokens = doc_buffer.pop(pick)
                    row_buffer[row, filled:filled + room] = torch.tensor(tokens[:room], dtype=torch.long)
                    filled += room
                else:
                    tokens = doc_buffer.pop(pick)
                    row_buffer[row, filled:filled + len(tokens)] = torch.tensor(tokens, dtype=torch.long)
                    filled += len(tokens)

        # Shift by one token to form the (input, target) pair, then move the
        # whole staging buffer to the output device in a single copy.
        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        gpu_buffer.copy_(cpu_buffer, non_blocking=on_cuda)
        yield inputs, targets, epoch
# ---------------------------------------------------------------------------
# Evaluation (DO NOT CHANGE — this is the fixed metric)
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB): vocab size-independent evaluation metric.

    Sums per-token cross-entropy (in nats), sums target byte lengths,
    then converts nats/byte to bits/byte. Special tokens (byte length 0)
    are excluded from both sums.
    Uses fixed MAX_SEQ_LEN so results are comparable across configs.

    Args:
        model: Module called as ``model(x, y, reduction='none')``; its first
            parameter's device is where the accumulators are placed.
        tokenizer: Tokenizer handed to make_dataloader for the "val" split.
        batch_size: Rows per validation batch.

    Returns:
        Bits per byte over ``EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)``
        validation batches, as a Python float.

    Raises:
        ZeroDivisionError: If no target bytes were counted (for example when
            the step count is 0).

    Perf: accumulates on the model's device and reads the totals back once at
    the end instead of calling ``.item()`` per batch.
    """
    _dev = next(model.parameters()).device
    token_bytes = get_token_bytes(device=_dev)
    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)

    # Device-resident accumulators - avoid a per-batch .item() sync.
    total_nats_t = torch.zeros(1, device=_dev, dtype=torch.float64)
    total_bytes_t = torch.zeros(1, device=_dev, dtype=torch.int64)

    for _ in range(steps):
        # The loader yields views into a single buffer that it overwrites on
        # every next(). Fetching the following batch before the forward pass
        # would replace x and y in place, so each batch is fetched only when
        # the previous one has been fully queued for scoring.
        x, y, _epoch = next(val_loader)
        loss_flat = model(x, y, reduction='none').view(-1)
        nbytes = token_bytes[y.view(-1)]
        mask = nbytes > 0
        total_nats_t += (loss_flat * mask).sum()
        total_bytes_t += nbytes.sum()

    # Single device->host sync at the end.
    total_nats = total_nats_t.item()
    total_bytes = total_bytes_t.item()
    return total_nats / (math.log(2) * total_bytes)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: download shards, then train the tokenizer.
    cli = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
    cli.add_argument(
        "--num-shards",
        type=int,
        default=10,
        help="Number of training shards to download (-1 = all). Val shard is always pinned.",
    )
    cli.add_argument(
        "--download-workers",
        type=int,
        default=8,
        help="Number of parallel download workers",
    )
    args = cli.parse_args()

    # -1 is shorthand for "every training shard".
    num_shards = args.num_shards
    if num_shards == -1:
        num_shards = MAX_SHARD

    print(f"Cache directory: {CACHE_DIR}")
    print()

    # Step 1: fetch the data shards.
    download_data(num_shards, download_workers=args.download_workers)
    print()

    # Step 2: train (or reuse) the tokenizer.
    train_tokenizer()
    print()
    print("Done! Ready to train.")
|