# chatbot / train_script_3.py
# frc 10252
# add files
# 3905c4a
import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# from torch.cuda.amp import autocast, GradScaler
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from tqdm import tqdm
from datasets import load_dataset
from components.model import GPTModel
from components.dataset import TextDataset
# Load the TinyChat dataset from the Hugging Face Hub.
dataset = load_dataset("starhopp3r/TinyChat")
# Sanity-check the data: the first 500 characters of training example #100,
# then the full record at index 600000.
# NOTE(review): assumes the train split has more than 600000 rows — confirm.
print(dataset["train"][100]["text"][:500])
print(dataset["train"][600000])

# Extend GPT-2's BPE vocabulary with the two chat markers. They take the
# next free ids directly after the base vocabulary.
base_encoding = tiktoken.get_encoding("gpt2")
special_tokens = {
    "[INST]": base_encoding.n_vocab,  # next available token id
    "[/INST]": base_encoding.n_vocab + 1,
}
# Build a new encoding that merges GPT-2's tokens with the chat special tokens.
tokenizer = tiktoken.Encoding(
    name="gpt2_with_inst",
    pat_str=base_encoding._pat_str,
    mergeable_ranks=base_encoding._mergeable_ranks,
    special_tokens={**base_encoding._special_tokens, **special_tokens},
)
def encode(text):
    """Convert *text* into token ids using the extended GPT-2 tokenizer.

    ``[INST]`` and ``[/INST]`` are explicitly allowed. When they appear in
    *text*, they are encoded as their dedicated special-token ids.
    """
    chat_markers = {"[INST]", "[/INST]"}
    return tokenizer.encode(text, allowed_special=chat_markers)
def decode(tokens):
    """Map a sequence of token ids back to text with the extended tokenizer."""
    text = tokenizer.decode(tokens)
    return text
# Round-trip a sample string containing both chat markers through
# encode/decode, printing the ids and the reconstructed text.
print("testing encoding and decoding functions:")
sample_ids = encode("[INST] Hello, world! [/INST]")
print(sample_ids)
print(decode(sample_ids))
# ---- hyperparameters --------------------------------------------------------
train_model = True  # False: skip training and load the saved checkpoint instead
periodic_outputs = False  # True: print a decoded sample batch every 100 steps
block_size = 128  # context window in tokens (generation crops to this length)
n_layers = 16
n_heads = 8
dropout_p = 0.1
batch_size = 64
learning_rate = 3e-4
n_embedding = 256
max_iters = 400000  # number of optimizer steps, not epochs
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- data -------------------------------------------------------------------
train_dataset = TextDataset(dataset, block_size=block_size)
# Cap worker processes at the machine's core count. A hard-coded 16
# oversubscribes smaller machines (PyTorch warns and throughput suffers);
# on machines with >= 16 cores this is unchanged.
num_workers = min(16, os.cpu_count() or 1)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

# ---- model and optimisation objects ----------------------------------------
# Use the extended tokenizer's vocabulary size so the [INST]/[/INST] ids fit.
vocab_size = tokenizer.n_vocab
model = GPTModel(
    vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
# Trade some float32 matmul precision for speed ("high" setting).
torch.set_float32_matmul_precision("high")
# Gradient scaler paired with the float16 autocast used in the training loop.
scaler = GradScaler(device)
if train_model:
    # torch.compile wraps `model`, so training through `compiled_model`
    # updates `model`'s own parameters (which is what gets checkpointed).
    compiled_model = torch.compile(model)
    pbar = tqdm(range(max_iters), desc="Training", ncols=100)
    data_iter = iter(train_dataloader)
    for count in pbar:
        # Train for a fixed number of steps: restart the loader whenever it
        # runs out of batches instead of stopping at the end of an epoch.
        try:
            xb, yb = next(data_iter)
        except StopIteration:
            data_iter = iter(train_dataloader)
            xb, yb = next(data_iter)

        if periodic_outputs and count % 100 == 0:
            # Debug view of the first sample in the batch: decoded text and raw ids.
            print("xb decoded: ", decode(xb[0].tolist()))
            print("yb decoded: ", decode(yb[0].tolist()))
            print("---" * 10)
            print("xb raw: ", xb[0].tolist())
            print("yb raw: ", yb[0].tolist())

        xb, yb = xb.to(device), yb.to(device)

        # Forward pass under float16 autocast (mixed precision).
        with autocast(device, dtype=torch.float16):
            logits = compiled_model(xb)
            # reshape (rather than view) also works if logits are non-contiguous.
            loss = loss_fn(logits.reshape(-1, vocab_size), yb.reshape(-1))

        # Backward pass with gradient scaling for float16 training.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the latest loss on the progress bar.
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
checkpoint_path = "checkpoints/gpt_model-1.pth"
if train_model:
    # Make sure the directory exists. Otherwise torch.save fails only after
    # the entire training run has finished, and the trained weights are lost.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
else:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # weights_only restricts unpickling to tensors and plain containers.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True)
    )
@torch.no_grad()
def generate_text(
    model, prompt, max_new_tokens, block_size, device, temperature=1.0, top_k=None
):
    """Autoregressively sample a continuation of *prompt*.

    Args:
        model: Language model called as ``model(tokens)``. Its output is indexed
            as ``[:, -1, :]`` to get next-token logits (assumes shape
            ``(batch, seq, vocab)`` — TODO confirm against GPTModel).
        prompt: Text to continue. It is tokenized with this file's ``encode``.
        max_new_tokens: Number of tokens to sample.
        block_size: Maximum context length fed to the model. Older tokens are
            cropped from the left.
        device: Device on which to place the token tensor.
        temperature: Logits are divided by this before softmax. The default
            1.0 samples from the model's distribution unchanged.
        top_k: If given, sample only among the ``top_k`` highest-scoring
            tokens. The default ``None`` considers the full vocabulary.

    Returns:
        The decoded prompt followed by the sampled continuation.

    Raises:
        ValueError: If ``temperature`` is not positive or ``top_k`` < 1.

    The model is switched to eval mode while sampling. Its previous
    train/eval mode is restored afterwards, so that later training does not
    silently run with dropout disabled.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k is not None and top_k < 1:
        raise ValueError("top_k must be at least 1")

    was_training = model.training
    model.eval()
    try:
        # Shape (1, prompt_len): a batch containing a single sequence.
        tokens = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(max_new_tokens):
            context = tokens[:, -block_size:]
            logits = model(context)[:, -1, :]  # (batch=1, vocab)
            logits = logits / temperature
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, next_token), dim=1)
    finally:
        model.train(was_training)

    return decode(tokens[0].tolist())
# Report the model size, then sample a continuation of a fixed prompt.
param_count = sum(p.numel() for p in model.parameters())
print(f"Model has {param_count / 1000000:.6f} million parameters.")

# NOTE(review): the prompt has a closing "[/INST]" but no opening "[INST]".
# Confirm this matches the chat format the model was trained on.
prompt = "this new company does [/INST]"
sample = generate_text(
    model,
    prompt,
    max_new_tokens=500,
    block_size=block_size,
    device=device,
)
print(sample)