abersbail's picture
Add small GPT Python Space
79078fe verified
import random
import torch
from .data import build_training_text
from .model import SmallGPTModel
from .tokenizer import WordTokenizer
def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: Integer passed unchanged to ``random.seed`` and
            ``torch.manual_seed``.
    """
    # Order matches the original: the stdlib generator first, then torch.
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
def create_model_and_tokenizer(config, extra_text=""):
    """Build the training corpus, fit a tokenizer on it, and create a model.

    Args:
        config: Object exposing ``block_size``, ``d_model``, ``n_heads``,
            ``n_layers`` and ``dropout`` attributes, which are forwarded to
            ``SmallGPTModel``.
        extra_text: Text handed to ``build_training_text``; defaults to the
            empty string.

    Returns:
        A ``(model, tokenizer, encoded)`` tuple, where ``encoded`` is a
        ``torch.long`` tensor holding the tokenizer's encoding of the corpus
        (encoded with ``add_bos=True`` and ``add_eos=True``).
    """
    corpus = build_training_text(extra_text)
    tokenizer = WordTokenizer().fit(corpus)

    token_ids = tokenizer.encode(corpus, add_bos=True, add_eos=True)
    encoded = torch.tensor(token_ids, dtype=torch.long)

    # The vocabulary size comes from the fitted tokenizer; every other
    # hyperparameter is read from the config object.
    model_kwargs = {
        "vocab_size": tokenizer.vocab_size,
        "block_size": config.block_size,
        "d_model": config.d_model,
        "n_heads": config.n_heads,
        "n_layers": config.n_layers,
        "dropout": config.dropout,
    }
    model = SmallGPTModel(**model_kwargs)

    return model, tokenizer, encoded
def build_batch(encoded, block_size, batch_size):
    """Sample a random batch of next-token training pairs from ``encoded``.

    Each row of ``x`` is a contiguous window of ``encoded``; the matching row
    of ``y`` is the same window shifted one position to the right.

    Args:
        encoded: Tensor of token ids, sliced along its first dimension.
        block_size: Maximum window length.
        batch_size: Number of windows to sample.

    Returns:
        A tuple ``(x, y)`` of stacked tensors, each with ``batch_size`` rows
        of ``min(block_size, len(encoded) - 1)`` tokens.

    Raises:
        ValueError: If ``encoded`` holds fewer than two tokens, since no
            (input, target) pair can be formed.
    """
    total = len(encoded)
    if total < 2:
        raise ValueError(
            "encoded must contain at least 2 tokens to build a batch, "
            f"got {total}"
        )

    # Clamp the window so x and y always have the same length, even when the
    # sequence is shorter than block_size + 1.
    span = min(block_size, total - 1)

    # A window starting at s needs tokens up to index s + span, so valid
    # starts are 0 .. total - span - 1. torch.randint's upper bound is
    # exclusive, hence total - span (always >= 1 here).
    num_starts = total - span
    starts = torch.randint(0, num_starts, (batch_size,)).tolist()

    x = torch.stack([encoded[s : s + span] for s in starts])
    y = torch.stack([encoded[s + 1 : s + span + 1] for s in starts])
    return x, y
def train_model(model, encoded, config, steps):
    """Run ``steps`` AdamW updates on ``model`` and return the loss history.

    Args:
        model: Module called as ``model(x, targets=y)``; the second element
            of its return value is used as the loss.
        encoded: Token data forwarded to ``build_batch`` on every step.
        config: Object exposing ``learning_rate``, ``block_size`` and
            ``batch_size`` attributes.
        steps: Number of optimisation steps to perform.

    Returns:
        A list of Python floats, one loss value per step, in step order.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    model.train()

    history = []
    step = 0
    while step < steps:
        inputs, targets = build_batch(
            encoded, config.block_size, config.batch_size
        )
        _logits, loss = model(inputs, targets=targets)

        # Clear gradients left over from the previous step before
        # accumulating the new ones.
        opt.zero_grad()
        loss.backward()
        opt.step()

        history.append(float(loss.item()))
        step += 1

    return history