"""
Milestone 1: Character-level tokenizer and data loading.
Reads tiny Shakespeare, builds a vocab from all unique characters,
provides encode/decode, and a get_batch() function for training.
"""
import torch
# ── Config ────────────────────────────────────────────────────────────────────
DATA_PATH = "data/input.txt"
BLOCK_SIZE = 256  # context length (tokens per sample)
BATCH_SIZE = 64   # samples per batch

# Prefer CUDA, then Apple-Silicon MPS, then plain CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
# ── Load data ─────────────────────────────────────────────────────────────────
# Read the whole corpus into memory as one string. The explicit encoding avoids
# the platform-dependent default (e.g. cp1252 on Windows) silently mangling or
# rejecting non-ASCII characters in the text.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    text = f.read()
# ── Build vocab ───────────────────────────────────────────────────────────────
# One token per distinct character; sorting makes id assignment deterministic.
chars = sorted(set(text))
VOCAB_SIZE = len(chars)
itos = dict(enumerate(chars))             # int -> char
stoi = {ch: i for i, ch in itos.items()}  # char -> int
def encode(s: str) -> list[int]:
    """Map each character of *s* to its integer token id (KeyError if unseen)."""
    return list(map(stoi.__getitem__, s))
def decode(ids: list[int]) -> str:
    """Inverse of encode(): turn a sequence of token ids back into text."""
    return "".join(map(itos.__getitem__, ids))
# ── Train / val split ─────────────────────────────────────────────────────────
# Tokenize the full corpus once; first 90% trains, the final 10% validates.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]
# ── Batch sampler ─────────────────────────────────────────────────────────────
def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of contiguous sequences from the chosen split.

    Args:
        split: "train" or "val" — which portion of the data to draw from.

    Returns:
        x: (BATCH_SIZE, BLOCK_SIZE) long tensor of input token ids, on DEVICE.
        y: (BATCH_SIZE, BLOCK_SIZE) long tensor of targets — x shifted right
           by one, so y[b, t] is the token that follows x[b, t].

    Raises:
        ValueError: if split is neither "train" nor "val" (previously any
        unrecognized split silently fell through to the validation data).
    """
    if split not in ("train", "val"):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    src = train_data if split == "train" else val_data
    # Random start offsets; the upper bound leaves room for BLOCK_SIZE + 1
    # tokens, since the shifted target sequence y consumes one extra token.
    ix = torch.randint(len(src) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([src[i : i + BLOCK_SIZE] for i in ix])
    y = torch.stack([src[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    return x.to(DEVICE), y.to(DEVICE)
# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Corpus / vocab statistics.
    print(f"Dataset length : {len(text):,} characters")
    print(f"Vocab size : {VOCAB_SIZE} unique chars")
    print(f"Train tokens : {len(train_data):,}")
    print(f"Val tokens : {len(val_data):,}")
    print(f"Device : {DEVICE}")
    print()
    # Encode/decode round-trip on a small prefix of the corpus.
    probe = text[:100]
    ids = encode(probe)
    print(f"Sample text : {repr(probe[:40])}")
    print(f"Encoded[:10] : {ids[:10]}")
    print(f"Round-trip OK : {decode(ids) == probe}")
    print()
    # One sampled batch to confirm shapes and device placement.
    xb, yb = get_batch("train")
    print(f"Batch x shape : {xb.shape} (on {xb.device})")
    print(f"Batch y shape : {yb.shape} (on {yb.device})")
    print(f"x[0,:8] : {xb[0,:8].tolist()}")
    print(f"y[0,:8] : {yb[0,:8].tolist()} (x shifted by 1)")