import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken

# from torch.cuda.amp import autocast, GradScaler  # old import path, replaced by torch.amp below
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from tqdm import tqdm
from datasets import load_dataset

from components.model import GPTModel
from components.dataset import TextDataset

# Load dataset
dataset = load_dataset("starhopp3r/TinyChat")
print(
    dataset["train"][100]["text"][:500]
)  # Print the first 500 characters of one training sample
print(dataset["train"][600000])

# Tokenizer: start from the base GPT-2 encoding
base_encoding = tiktoken.get_encoding("gpt2")
special_tokens = {
    "[INST]": base_encoding.n_vocab,  # next available token id
    "[/INST]": base_encoding.n_vocab + 1,
}

# Create a new encoding that merges GPT-2's tokens with our special tokens
tokenizer = tiktoken.Encoding(
    name="gpt2_with_inst",
    pat_str=base_encoding._pat_str,
    mergeable_ranks=base_encoding._mergeable_ranks,
    special_tokens={**base_encoding._special_tokens, **special_tokens},
)
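# Note: adding the two [INST]/[/INST] special tokens grows the vocabulary from
# GPT-2's 50257 tokens to 50259, which is why the model's embedding and output
# layers are sized from tokenizer.n_vocab below rather than from the base encoding.
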
def encode(text):
    return tokenizer.encode(text, allowed_special={"[INST]", "[/INST]"})


def decode(tokens):
    return tokenizer.decode(tokens)


print("testing encoding and decoding functions:")
print(encode("[INST] Hello, world! [/INST]"))
print(decode(encode("[INST] Hello, world! [/INST]")))

# hyperparameters
train_model = True
periodic_outputs = False
block_size = 128
n_layers = 16
n_heads = 8
dropout_p = 0.1
batch_size = 64
learning_rate = 3e-4
n_embedding = 256
max_iters = 400000
device = "cuda" if torch.cuda.is_available() else "cpu"
train_dataset = TextDataset(dataset, block_size=block_size)
train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16
)

# define objects
vocab_size = tokenizer.n_vocab
model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(
    device
)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# training loop
torch.set_float32_matmul_precision("high")
scaler = GradScaler(device)
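# Mixed precision: autocast runs the forward pass and loss in float16 on CUDA,
# while GradScaler scales the loss before backward() so small gradients don't
# underflow in float16, then unscales them before the optimizer step.
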
if train_model:
    compiled_model = torch.compile(model)
    pbar = tqdm(range(max_iters), desc="Training", ncols=100)
    data_iter = iter(train_dataloader)
    for count in pbar:
        try:
            xb, yb = next(data_iter)
        except StopIteration:
            # dataloader exhausted; restart it
            data_iter = iter(train_dataloader)
            xb, yb = next(data_iter)
        if count % 100 == 0 and periodic_outputs:
            # print out xb, yb, encoded too
            print("xb decoded: ", decode(xb[0].tolist()))
            print("yb decoded: ", decode(yb[0].tolist()))
            print("---" * 10)
            print("xb raw: ", xb[0].tolist())
            print("yb raw: ", yb[0].tolist())
        xb, yb = xb.to(device), yb.to(device)
        # forward pass in mixed precision
        with autocast(device, dtype=torch.float16):
            logits = compiled_model(xb)
            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))
        # backward pass with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # update bar text dynamically
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
if train_model:
    os.makedirs("checkpoints", exist_ok=True)  # make sure the checkpoint dir exists
    torch.save(model.state_dict(), "checkpoints/gpt_model-1.pth")
else:
    model.load_state_dict(torch.load("checkpoints/gpt_model-1.pth", map_location=device))

def generate_text(model, prompt, max_new_tokens, block_size, device):
    model.eval()
    # Encode the prompt text into token IDs using our custom encode function
    tokens = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():  # no gradients needed during generation
        for _ in range(max_new_tokens):
            # Only keep the last block_size tokens for context
            input_tokens = tokens[:, -block_size:]
            # Get logits and take the last token's distribution
            logits = model(input_tokens)
            logits = logits[:, -1, :]  # (batch=1, vocab)
            probs = F.softmax(logits, dim=-1)
            # Sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, next_token), dim=1)
    # Decode back into text using our custom decode function
    output_tokens = tokens[0].tolist()
    output_text = decode(output_tokens)
    return output_text
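
# A possible refinement (not part of the original script): temperature and top-k
# sampling instead of sampling from the full softmax. The helper below is a
# hypothetical sketch using only torch ops already imported above.
def sample_top_k(logits, k=50, temperature=1.0):
    # logits: (1, vocab) scores for the next token
    scaled = logits / temperature
    topk_vals, topk_idx = torch.topk(scaled, k, dim=-1)  # keep the k highest-scoring tokens
    probs = F.softmax(topk_vals, dim=-1)                 # renormalise over just those k
    choice = torch.multinomial(probs, num_samples=1)     # sample an index into the top-k
    return topk_idx.gather(-1, choice)                   # map back to a vocabulary id
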
# print model parameters
print(
    f"Model has {sum(p.numel() for p in model.parameters()) / 1000000:.6f} million parameters."
)

prompt = "this new company does [/INST]"
print(
    generate_text(
        model, prompt, max_new_tokens=500, block_size=block_size, device=device
    )
)