|
|
|
|
|
|
|
|
import os |
|
|
from contextlib import nullcontext |
|
|
|
|
|
import bdh |
|
|
import numpy as np |
|
|
import requests |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
|
|
|
dtype = ( |
|
|
"bfloat16" |
|
|
if torch.cuda.is_available() and torch.cuda.is_bf16_supported() |
|
|
else "float16" |
|
|
) |
|
|
ptdtype = { |
|
|
"float32": torch.float32, |
|
|
"bfloat16": torch.bfloat16, |
|
|
"float16": torch.float16, |
|
|
}[dtype] |
|
|
ctx = ( |
|
|
torch.amp.autocast(device_type=device.type, dtype=ptdtype) |
|
|
if "cuda" in device.type |
|
|
else nullcontext() |
|
|
) |
|
|
scaler = torch.amp.GradScaler(device=device.type, enabled=(dtype == "float16")) |
|
|
torch.manual_seed(1337) |
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
torch.backends.cudnn.allow_tf32 = True |
|
|
print(f"Using device: {device} with dtype {dtype}") |
|
|
|
|
|
|
|
|
|
|
|
# --- Model -------------------------------------------------------------------
# Model configuration with the defaults provided by the bdh module.
BDH_CONFIG = bdh.BDHConfig()

# --- Batching ----------------------------------------------------------------
BLOCK_SIZE = 512  # length, in bytes, of each sampled training window
BATCH_SIZE = 32  # number of windows drawn per batch

# --- Optimisation ------------------------------------------------------------
MAX_ITERS = 3000  # number of optimizer steps in the training loop
LEARNING_RATE = 1e-3  # passed to AdamW as `lr`
WEIGHT_DECAY = 0.1  # passed to AdamW as `weight_decay`
LOG_FREQ = 100  # print the averaged loss every LOG_FREQ steps

# --- Data --------------------------------------------------------------------
# The corpus is stored next to this script.
input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
|
|
|
|
|
|
|
|
|
|
|
def fetch_data():
    """Download the Tiny Shakespeare corpus to ``input_file_path`` if absent.

    Does nothing when the file already exists. Otherwise the corpus is
    downloaded completely and validated *before* the destination file is
    created, so a failed or rejected request never leaves behind an empty or
    bogus ``input.txt`` that later runs would mistake for real data.

    Raises:
        requests.RequestException: if the request times out, the connection
            fails, or the server answers with an HTTP error status.
        OSError: if the destination file cannot be written.
    """
    if os.path.exists(input_file_path):
        return

    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

    # Bound the wait so a stalled connection cannot hang the script forever.
    response = requests.get(data_url, timeout=60)
    # Refuse to store an HTTP error page as training data.
    response.raise_for_status()

    # The corpus is consumed as raw bytes later on, so store the payload
    # byte-for-byte instead of round-tripping it through a text-mode file
    # (which depends on the locale encoding and may translate newlines).
    with open(input_file_path, "wb") as f:
        f.write(response.content)
|
|
|
|
|
|
|
|
def get_batch(split):
    """Draw one random batch of input/target byte windows from the corpus.

    The file at ``input_file_path`` is memory-mapped as unsigned bytes on
    every call. The first 90% of it is used when ``split == "train"``; any
    other value selects the remaining 10%.

    Args:
        split: ``"train"`` for the training portion, anything else for the
            held-out portion.

    Returns:
        A pair ``(x, y)`` of int64 tensors on ``device``, each with
        ``BATCH_SIZE`` rows of ``BLOCK_SIZE`` values. ``y`` holds the same
        windows as ``x`` shifted one position to the right.
    """
    corpus = np.memmap(input_file_path, dtype=np.uint8, mode="r")
    boundary = int(0.9 * len(corpus))
    data = corpus[:boundary] if split == "train" else corpus[boundary:]

    # One random start offset per row of the batch.
    starts = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,)).tolist()

    def window(begin):
        # Copy BLOCK_SIZE bytes starting at `begin` into an int64 tensor.
        return torch.from_numpy(data[begin : begin + BLOCK_SIZE].astype(np.int64))

    x = torch.stack([window(start) for start in starts])
    y = torch.stack([window(start + 1) for start in starts])

    if not torch.cuda.is_available():
        return x.to(device), y.to(device)

    # Pin the host tensors so the copies to the GPU can be issued
    # asynchronously.
    x = x.pin_memory().to(device, non_blocking=True)
    y = y.pin_memory().to(device, non_blocking=True)
    return x, y
|
|
|
|
|
|
|
|
def eval(model):
    """Call ``model.eval()`` and return nothing.

    NOTE(review): the name shadows Python's built-in ``eval`` inside this
    module. It is kept unchanged so existing references keep working.
    """
    model.eval()
    return None
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Make sure the corpus is on disk before the first batch is sampled.
    fetch_data()

    # Build the model on the selected device and compile it.
    model = bdh.BDH(BDH_CONFIG).to(device)
    model = torch.compile(model)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    # First batch; every later batch is fetched right after the forward pass.
    x, y = get_batch("train")

    # Running sum of the losses and number of steps since the last log line.
    loss_acc = 0
    loss_steps = 0
    for step in range(MAX_ITERS):
        with ctx:
            # The model returns (logits, loss); only the loss is needed here.
            _, loss = model(x, y)
        x, y = get_batch("train")

        # Accumulate a detached view: the running sum is only used for
        # logging, and adding the attached tensor would chain the autograd
        # graphs of every step in the logging window together.
        loss_acc += loss.detach()
        loss_steps += 1

        # Scaled backward pass and optimizer step; the scaler is a
        # pass-through when it was constructed disabled.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        if step % LOG_FREQ == 0:
            # Mean loss since the previous log line, then restart the window.
            print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item() / loss_steps:.3}")
            loss_acc = 0
            loss_steps = 0

    print("Training done, now generating a sample ")
    model.eval()

    # The prompt is fed to the model as its raw UTF-8 byte values.
    prompt = torch.tensor(
        bytearray("To be or ", "utf-8"), dtype=torch.long, device=device
    ).unsqueeze(0)

    # Sampling needs no gradients, so do not record an autograd graph.
    with torch.no_grad():
        ret = model.generate(prompt, max_new_tokens=100, top_k=3)

    # Turn the generated values back into bytes and decode them; byte
    # sequences that are not valid UTF-8 are shown as backslash escapes.
    ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0).tolist()).decode(
        errors="backslashreplace"
    )
    print(ret_decoded)