"""
Step-by-step data preparation for nano GPT.
We work at the CHARACTER LEVEL:
1. Load the tiny Shakespeare text file
2. Discover all unique characters (our vocabulary)
3. Build encoder (char -> int) and decoder (int -> char)
4. Encode the entire text into integers
5. Split into train (90%) and val (10%)
6. Save as PyTorch tensors for fast loading during training
"""
import torch
import os
# ---------------------------------------------------------------------------
# 1. Load the raw text
# ---------------------------------------------------------------------------
DATA_FILE = os.path.join(os.path.dirname(__file__), "input.txt")
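# Optional convenience (not part of the original pipeline): if input.txt is
# missing, fetch the tiny Shakespeare file that nanoGPT-style tutorials use.
# The URL below is the canonical char-rnn copy; swap in your own source if needed.
if not os.path.exists(DATA_FILE):
    import urllib.request
    url = (
        "https://raw.githubusercontent.com/karpathy/char-rnn/"
        "master/data/tinyshakespeare/input.txt"
    )
    urllib.request.urlretrieve(url, DATA_FILE)
    print(f"Downloaded tiny Shakespeare to {DATA_FILE}")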
with open(DATA_FILE, "r", encoding="utf-8") as f:
    text = f.read()
print(f"Total characters in dataset: {len(text):,}")
print(f"First 200 chars:\n{text[:200]}\n")
# ---------------------------------------------------------------------------
# 2. Build the vocabulary
# ---------------------------------------------------------------------------
# We find every unique character and sort them to get a stable ordering.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"Vocabulary size (unique chars): {vocab_size}")
print(f"Characters: {''.join(chars)}")
# ---------------------------------------------------------------------------
# 3. Create encoder / decoder mappings
# ---------------------------------------------------------------------------
stoi = {ch: i for i, ch in enumerate(chars)} # string to int
itos = {i: ch for i, ch in enumerate(chars)} # int to string
# Encoder/decoder as simple one-liners
encode = lambda s: [stoi[c] for c in s] # take a string, return list of ints
decode = lambda l: "".join([itos[i] for i in l]) # take list of ints, return string
# Quick sanity check
assert decode(encode("hello")) == "hello"
print("\nEncode 'hello':", encode("hello"))
print("Decode back :", decode(encode("hello")))
# ---------------------------------------------------------------------------
# 4. Encode the entire dataset
# ---------------------------------------------------------------------------
data = torch.tensor(encode(text), dtype=torch.long)
print(f"\nEncoded data tensor shape: {data.shape}, dtype: {data.dtype}")
# ---------------------------------------------------------------------------
# 5. Train / val split
# ---------------------------------------------------------------------------
n = int(0.9 * len(data)) # first 90% for training
train_data = data[:n]
val_data = data[n:]
print(f"Train tokens: {len(train_data):,}")
print(f"Val tokens : {len(val_data):,}")
# ---------------------------------------------------------------------------
# 6. Save the processed data and metadata
# ---------------------------------------------------------------------------
# We save everything needed for training so the train script doesn't
# need to know about the original text file.
torch.save({
"train": train_data,
"val": val_data,
"vocab_size": vocab_size,
"chars": chars,
"stoi": stoi,
"itos": itos,
}, os.path.join(os.path.dirname(__file__), "data.pt"))
print("\nSaved: data.pt")
print("All done! Ready for training.")