from model import CharacterLevelTokenizer, Config, PotterGPT
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from tokenizers import Tokenizer
import matplotlib.pyplot as plt

torch.manual_seed(1357)
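# Read the full Harry Potter corpus; the character-level tokenizer below is built from this text.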
with open('data/harry_potter_data', 'r', encoding='utf-8') as f:
    data = f.read()
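
# Holds the tokenized text and serves random (input, target) batches.
# The first 90% of the tokens is the training split; the last 10% is validation.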
class Dataset:
    def __init__(self, Config, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        if self.is_test:
            self.data = self.full_data[int(0.9 * len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9 * len(self.full_data))]
        self.block_size = Config.block_size
        self.batch_size = Config.batch_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE
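    # Sample a batch of random contiguous blocks: x is block_size tokens,
    # y is the same window shifted one token ahead (next-token targets).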
    def get(self):
        ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
        x = torch.stack([self.data[i:i+self.block_size] for i in ix])
        y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
        return x, y
# tokenizer = Tokenizer.from_file('tokenizer/potter.json')
tokenizer = CharacterLevelTokenizer(data)

# Training
train_ds = Dataset(Config)
val_ds = Dataset(Config, is_test=True)

lm = PotterGPT(Config)
lm = lm.to(device=Config.device)

optim = torch.optim.Adam(lm.parameters(), lr=Config.lr)
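
# Cross-entropy over every position: flatten the (B, T, C) logits to (B*T, C)
# and the targets to (B*T,) so each time step is scored independently.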
def loss_fn(logits, targets):
    B, T, C = logits.shape
    logits = logits.view(B * T, C)
    targets = targets.view(B * T)
    loss = F.cross_entropy(logits, targets)
    return loss
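
# One optimizer step per randomly sampled batch; the loss is printed every 10%
# of the run and the per-step losses are returned for plotting.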
def train_N_iters():
    lm.train()
    train_step_losses = []
    for batch in tqdm(range(Config.train_iters)):
        optim.zero_grad()
        inputs, targets = train_ds.get()
        inputs, targets = inputs.to(device=Config.device), targets.to(device=Config.device)
        logits = lm(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        optim.step()
        train_step_losses.append(loss.item())
        if batch % (Config.train_iters // 10) == 0 or batch == Config.train_iters - 1:
            print(f"batch {batch} train step loss: {loss.item()}")
        del inputs, targets, loss, logits
    return train_step_losses
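
# The same loop over the validation split, without parameter updates.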
@torch.no_grad()  # no gradients are needed during evaluation
def valid_N_iters():
    lm.eval()
    val_step_losses = []
    for batch in tqdm(range(Config.val_iters)):
        inputs, targets = val_ds.get()
        inputs, targets = inputs.to(device=Config.device), targets.to(device=Config.device)
        logits = lm(inputs)
        loss = loss_fn(logits, targets)
        val_step_losses.append(loss.item())
        if batch % (Config.val_iters // 10) == 0 or batch == Config.val_iters - 1:
            print(f"batch {batch} valid step loss: {loss.item()}")
        del inputs, targets, loss, logits
    return val_step_losses
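
# Persist the trained weights to ./potterGPT/potterGPT.pth.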
def save_lm():
    state_dict = lm.state_dict()
    save_path = Path('./').resolve() / 'potterGPT'
    save_path.mkdir(exist_ok=True)
    model_path = save_path / 'potterGPT.pth'
    torch.save(state_dict, model_path)
def train_lm():
    train_losses = train_N_iters()
    valid_losses = valid_N_iters()
    save_lm()
    return train_losses, valid_losses

tl, vl = train_lm()
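
# Plot the per-step loss curves; note that the two curves have different lengths
# (Config.train_iters training steps vs. Config.val_iters validation steps).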
plt.plot(tl, label='train loss', color='orange')
plt.plot(vl, label='valid loss', color='blue')
plt.title('Potter GPT Losses')
plt.legend()
plt.show()
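
# Sample from the trained model at several lengths, seeding generation with a
# single zero token; generate() presumably appends `total` new tokens autoregressively.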
generated_texts = []
for length in [100, 300, 500, 700, 1000]:
    generated = lm.generate(
        torch.zeros((1, 1), dtype=torch.long, device=Config.device),  # initial context: token 0
        total=length
    )
    generated = tokenizer.decode(generated[0])
    text = f'generated ({length} tokens)\n{"=" * 50}\n{generated}\n{"=" * 50}\n\n'
    generated_texts.append(text)
    print(text)