File size: 3,522 Bytes
2d3ab90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Milestone 1: Character-level tokenizer and data loading.

Reads the Tiny Shakespeare corpus, builds a vocabulary from every unique
character, and provides encode()/decode() plus a get_batch() sampler that
draws random training or validation batches.
"""

import torch

# ── Config ────────────────────────────────────────────────────────────────────
DATA_PATH = "data/input.txt"
BLOCK_SIZE = 256  # tokens of context in each sampled example
BATCH_SIZE = 64   # examples drawn per get_batch() call

# Pick the first available accelerator: CUDA, then Apple's MPS backend,
# otherwise fall back to the CPU.
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# ── Load data ─────────────────────────────────────────────────────────────────
# Decode explicitly as UTF-8. Without `encoding=`, open() falls back to the
# platform's locale encoding (e.g. cp1252 on Windows), which can garble or
# reject non-ASCII characters and thereby silently change the vocabulary.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    text = f.read()

# ── Build vocab ───────────────────────────────────────────────────────────────
# Each distinct character of the corpus becomes one token; sorting makes the
# id assignment deterministic for a given text.
chars = sorted(set(text))
VOCAB_SIZE = len(chars)

itos = dict(enumerate(chars))               # int  -> char
stoi = {ch: i for i, ch in itos.items()}    # char -> int (inverse of itos)

def encode(s: str) -> list[int]:
    """Translate *s* character-by-character into vocabulary ids.

    Raises KeyError for any character absent from the corpus vocabulary.
    """
    return list(map(stoi.__getitem__, s))

def decode(ids: "list[int] | torch.Tensor") -> str:
    """Turn a sequence of vocabulary ids back into text.

    Accepts a list of ints or a 1-D integer tensor (for example one row of a
    batch returned by get_batch). Tensors are converted with ``.tolist()``
    first: iterating a tensor yields 0-d tensors, which hash by identity and
    therefore never match the int keys of ``itos``.

    Raises KeyError for an id outside the vocabulary.
    """
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    return "".join(itos[i] for i in ids)

# ── Train / val split ─────────────────────────────────────────────────────────
# Encode the whole corpus once; the first 90% is training data, the rest is
# held out for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # index where the validation split begins
train_data, val_data = data[:n], data[n:]

# ── Batch sampler ─────────────────────────────────────────────────────────────
def get_batch(split: str):
    """Return a random batch of (x, y) pairs from the requested split.

    Args:
        split: "train" samples from train_data; "val" samples from val_data.

    Returns:
        x: (BATCH_SIZE, BLOCK_SIZE) input token ids, moved to DEVICE.
        y: (BATCH_SIZE, BLOCK_SIZE) target token ids, moved to DEVICE, with
           y[b, t] == x[b, t + 1] for t < BLOCK_SIZE - 1 -- i.e. x shifted
           left by one, so each position's target is the next token.

    Raises:
        ValueError: if split is neither "train" nor "val", or if the chosen
            split has too few tokens for a single BLOCK_SIZE + 1 window.
    """
    if split == "train":
        src = train_data
    elif split == "val":
        src = val_data
    else:
        # Reject unknown names instead of silently using val_data, so a typo
        # such as "trian" cannot quietly sample the wrong split.
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # A window needs BLOCK_SIZE inputs plus one extra token for the final
    # target, so valid start offsets are 0 .. len(src) - BLOCK_SIZE - 1.
    n_starts = len(src) - BLOCK_SIZE
    if n_starts <= 0:
        raise ValueError(
            f"{split!r} split has {len(src)} tokens; need more than "
            f"BLOCK_SIZE={BLOCK_SIZE} to sample a batch"
        )

    ix = torch.randint(n_starts, (BATCH_SIZE,))
    x = torch.stack([src[i : i + BLOCK_SIZE] for i in ix])
    y = torch.stack([src[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    return x.to(DEVICE), y.to(DEVICE)


# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Corpus and split statistics.
    report = (
        f"Dataset length : {len(text):,} characters",
        f"Vocab size     : {VOCAB_SIZE} unique chars",
        f"Train tokens   : {len(train_data):,}",
        f"Val tokens     : {len(val_data):,}",
        f"Device         : {DEVICE}",
    )
    for line in report:
        print(line)
    print()

    # encode -> decode must reproduce the snippet exactly.
    snippet = text[:100]
    ids = encode(snippet)
    restored = decode(ids)
    print(f"Sample text    : {snippet[:40]!r}")
    print(f"Encoded[:10]   : {ids[:10]}")
    print(f"Round-trip OK  : {restored == snippet}")
    print()

    # One training batch: shapes, placement, and the one-token offset.
    xb, yb = get_batch("train")
    for label, batch in (("x", xb), ("y", yb)):
        print(f"Batch {label} shape  : {batch.shape}  (on {batch.device})")
    print(f"x[0,:8]        : {xb[0, :8].tolist()}")
    print(f"y[0,:8]        : {yb[0, :8].tolist()}  (x shifted by 1)")