BDHMODEL / train.py
vesakkivignesh's picture
Initial clean Space commit
d7d8e66
# Copyright Pathway Technology, Inc.
import os
from contextlib import nullcontext
import bdh
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# On a Mac you can also try
# device=torch.device('mps')
dtype = (
"bfloat16"
if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
else "float16"
) # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
ptdtype = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}[dtype]
ctx = (
torch.amp.autocast(device_type=device.type, dtype=ptdtype)
if "cuda" in device.type
else nullcontext()
)
scaler = torch.amp.GradScaler(device=device.type, enabled=(dtype == "float16"))
torch.manual_seed(1337)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
print(f"Using device: {device} with dtype {dtype}")
# Configuration
BDH_CONFIG = bdh.BDHConfig()
BLOCK_SIZE = 512
BATCH_SIZE = 32
MAX_ITERS = 3000
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.1
LOG_FREQ = 100
input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
# Fetch the tiny Shakespeare dataset
def fetch_data():
if not os.path.exists(input_file_path):
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
with open(input_file_path, "w") as f:
f.write(requests.get(data_url).text)
def get_batch(split):
# treat the file as bytes
data = np.memmap(input_file_path, dtype=np.uint8, mode="r")
if split == "train":
data = data[: int(0.9 * len(data))]
else:
data = data[int(0.9 * len(data)) :]
ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
x = torch.stack(
[torch.from_numpy((data[i : i + BLOCK_SIZE]).astype(np.int64)) for i in ix]
)
y = torch.stack(
[
torch.from_numpy((data[i + 1 : i + 1 + BLOCK_SIZE]).astype(np.int64))
for i in ix
]
)
if torch.cuda.is_available():
# pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(
device, non_blocking=True
)
else:
x, y = x.to(device), y.to(device)
return x, y
def eval(model):
model.eval()
if __name__ == "__main__":
fetch_data()
model = bdh.BDH(BDH_CONFIG).to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(
model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
x, y = get_batch("train")
loss_acc = 0
loss_steps = 0
for step in range(MAX_ITERS):
with ctx:
logits, loss = model(x, y)
x, y = get_batch("train")
loss_acc += loss
loss_steps += 1
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if step % LOG_FREQ == 0:
print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item() / loss_steps:.3}")
loss_acc = 0
loss_steps = 0
print("Training done, now generating a sample ")
model.eval()
prompt = torch.tensor(
bytearray("To be or ", "utf-8"), dtype=torch.long, device=device
).unsqueeze(0)
ret = model.generate(prompt, max_new_tokens=100, top_k=3)
ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0)).decode(
errors="backslashreplace"
)
print(ret_decoded)