In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# from torch.cuda.amp import autocast, GradScaler
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from tqdm import tqdm

In [2]:
from components.dataset import TextDataset
from components.model import GPTModel
from components.tokenizer import encode, decode, tokenizer

 from .autonotebook import tqdm as notebook_tqdm


In [3]:
from datasets import load_dataset

# Load the TinyChat corpus. The records are multi-turn chats written as
# "[INST] ... [/INST] ..." text (see the preview printed below).
# Alternative corpus tried earlier: load_dataset("wikimedia/wikipedia", "20231101.en")
dataset = load_dataset("starhopp3r/TinyChat")

# Sanity check: first 500 characters of record 100, then one complete raw record.
_train_split = dataset["train"]
_preview_text = _train_split[100]["text"]
print(_preview_text[:500])
print(_train_split[600000])

[INST] Hello, I feel a bit sad today because things seem hard to understand and move through. [/INST] I understand how you feel; sometimes life can be heavy like a thick substance we cannot lift. [INST] Yes, it can be very difficult, especially for young people trying to find their way. [/INST] Young minds often carry many questions that can weigh them down with worries and doubts. [INST] Sometimes, I wish everything would get better and we could all feel lighter again. [/INST] Hoping for better
{'text': "[INST] Do you think the disease spreading in the city is really as bad as it seems? [/INST] It does seem very clear that many people are crying over the current situation. [INST] Yes, I feel disgusted by how quickly it is spreading without control or care. [/INST] It makes me feel unwell just to think about how people's lives are affected deeply. [INST] I can’t believe some people ignore the danger and spread the disease even more. [/INST] That kind of behavior is truly unhelpful and 

In [4]:
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Run control: when False the training loop is skipped.
train_model = False

# Model architecture
block_size = 128      # number of tokens kept as context
n_embedding = 256
n_layers = 16
n_heads = 8
dropout_p = 0.1

# Optimisation
batch_size = 8
learning_rate = 3e-4
max_iters = 5000

# Compute device: use the GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [5]:
# Wrap the corpus in the project's TextDataset and batch it for training.
# NOTE(review): TextDataset lives in components.dataset; the training loop unpacks
# each batch as (inputs, targets), presumably block_size tokens each — confirm.
train_dataset = TextDataset(dataset, block_size=block_size)

# Never ask for more worker processes than the machine has cores: a fixed 16
# oversubscribes small machines and triggers a DataLoader warning.
num_workers = min(16, os.cpu_count() or 1)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # keep every batch the same size
    num_workers=num_workers,
)

In [6]:
# Build the model, its optimizer and the training criterion.

# Vocabulary size is taken from the tokenizer.
vocab_size = tokenizer.n_vocab

# Instantiate the network first, then move its parameters onto the target device.
model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size)
model = model.to(device)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

# Cross-entropy between the flattened logits and the flattened target ids.
loss_fn = nn.CrossEntropyLoss()

In [7]:


# Training loop (float16 autocast + gradient scaling).
# The body only runs when `train_model` is True.
torch.set_float32_matmul_precision('high')
scaler = GradScaler(device)
if train_model:
    # generate_text() switches the model to eval mode; make sure dropout is
    # active again so re-running this cell after sampling still trains correctly.
    model.train()
    compiled_model = torch.compile(model)

    pbar = tqdm(range(max_iters), desc="Training", ncols=100)
    data_iter = iter(train_dataloader)

    for count in pbar:
        # Fetch the next batch; when the dataloader is exhausted start a new
        # pass over the data instead of stopping before max_iters.
        try:
            xb, yb = next(data_iter)
        except StopIteration:
            data_iter = iter(train_dataloader)
            xb, yb = next(data_iter)

        # Every 100 steps show one decoded (input, target) pair as a sanity check.
        # tqdm.write keeps the progress bar intact, unlike a bare print().
        if count % 100 == 0:
            tqdm.write(f"xb decoded:  {decode(xb[0].tolist())}")
            tqdm.write(f"yb decoded:  {decode(yb[0].tolist())}")

        xb, yb = xb.to(device), yb.to(device)

        # Forward pass under autocast; logits are flattened to
        # (batch * time, vocab) to match the flattened targets.
        with autocast(device, dtype=torch.float16):
            logits = compiled_model(xb)
            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))

        # Backward pass with gradient scaling.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update the progress-bar text with the current loss.
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

 _C._set_float32_matmul_precision(precision)


In [8]:
# Persist freshly trained weights, or restore them when training was skipped.
checkpoint_path = "checkpoints/gpt_model-1.pth"
if train_model:
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
else:
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine;
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

In [12]:
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device,
                  temperature=1.0, top_k=None):
    """Autoregressively sample a continuation of ``prompt`` and return it as text.

    Args:
        model: Callable module mapping a ``(1, T)`` tensor of token ids to logits
            indexed as ``[batch, position, vocab]``.
        tokenizer: Object whose ``decode(list_of_ids)`` turns ids back into text.
        prompt: Text to condition on; must encode to at least one token.
        max_new_tokens: Number of tokens to sample.
        block_size: Maximum number of trailing tokens fed to the model per step.
        device: Device on which the token tensor is created.
        temperature: Divides the logits before softmax; ``1.0`` (default) leaves
            the distribution unchanged.
        top_k: When given, sample only among the ``top_k`` most likely tokens;
            ``None`` (default) samples from the full vocabulary.

    Returns:
        The decoded prompt followed by the sampled continuation.

    Raises:
        ValueError: If the prompt encodes to no tokens, ``temperature`` is not
            positive, or ``top_k`` is smaller than 1.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k is not None and top_k < 1:
        raise ValueError("top_k must be at least 1")

    # NOTE(review): encoding uses the module-level `encode` helper while decoding
    # uses the `tokenizer` argument — confirm both refer to the same vocabulary.
    prompt_ids = encode(prompt)
    if len(prompt_ids) == 0:
        raise ValueError("prompt must encode to at least one token")

    # Remember the current mode so sampling does not silently leave a model that
    # is about to be trained stuck in eval mode (dropout disabled).
    was_training = model.training
    model.eval()
    try:
        tokens = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(device)

        for _ in range(max_new_tokens):
            # Only keep the last block_size tokens for context
            input_tokens = tokens[:, -block_size:]

            # Logits for the final position only: (batch=1, vocab)
            logits = model(input_tokens)[:, -1, :]

            if temperature != 1.0:
                logits = logits / temperature
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, -1:]
                # Everything below the k-th best logit gets zero probability.
                logits = logits.masked_fill(logits < kth_best, float("-inf"))

            probs = F.softmax(logits, dim=-1)

            # Sample the next token and append it to the running sequence
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, next_token), dim=1)
    finally:
        model.train(was_training)

    # Decode back into text
    return tokenizer.decode(tokens[0].tolist())
 
# Report the model size, expressed in millions of parameters.
n_params_millions = sum(p.numel() for p in model.parameters()) / 1000000
print(f"Model has {n_params_millions} million parameters.")

# Sample a 500-token continuation of a chat-style prompt and show it.
prompt = "what do you think of books? [/INST]"
generated = generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=500,
    block_size=block_size,
    device=device,
)
print(generated)

Model has 38.402048 million parameters.
what do you think of books? [/INST] I think a book page can be fun and surprising. [INST] Yes, especially when I find a secret book to read in the pages. [/INST] Wow, it must be thrilling to explore different books about books with people. [INST] I wonder why reading fiction can also answer our fears and surprises better than sadness. [/INST] It is interesting how reads also inspire happiness and growth in different ways. [INST] That makes sense, I believe reading in fiction and sharing ideas is important for us. [/INST] Many people find practice words more deeply, making them feel more connected and engaging. [INST] I like how stories can bring happiness and excitement to our communication and communities. [/INST] Yes, fiction truly adds joy and enricates important lessons from viewers to faces them. [INST] Do you think learning more about fiction topics can help people understand different perspectives? [/INST] Definitely, talking about one ano