# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# A MinGPT + Lightning + xFormers example Code from Sean Naren (@seannaren)
# This is an hommage to https://github.com/karpathy/minGPT

import math
import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import rank_zero_info
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler

from xformers.factory.model_factory import xFormer, xFormerConfig
class GPT(pl.LightningModule):
    """The full GPT language model, with a context size of ``block_size``.

    The transformer trunk is assembled through the xFormers model factory;
    a final LayerNorm plus a bias-free linear head project the latent states
    back to vocabulary logits.
    """

    def __init__(
        self,
        vocab_size,
        weight_decay=0.1,
        betas=(0.9, 0.95),
        learning_rate=6e-4,
        n_embd=512,
        block_size=128,
        n_layer=8,
        n_head=8,
        resid_pdrop=0.1,
        attn_pdrop=0.1,
        mlp_pdrop=0.1,
        attention="scaled_dot_product",
        hidden_layer_multiplier=4,
        warmup_tokens=20,
        final_tokens=1000,
    ):
        super().__init__()

        # auto creates self.hparams from the method signature
        self.save_hyperparameters()

        # A list of the encoder or decoder blocks which constitute the Transformer.
        xformer_config = [
            {
                "reversible": False,  # Turn on to test the effect of using reversible layers
                "block_type": "encoder",
                "num_layers": self.hparams.n_layer,
                "dim_model": self.hparams.n_embd,
                "residual_norm_style": "post",
                "position_encoding_config": {
                    "name": "vocab",
                    "seq_len": self.hparams.block_size,
                    "vocab_size": self.hparams.vocab_size,
                },
                "multi_head_config": {
                    "num_heads": self.hparams.n_head,
                    "residual_dropout": self.hparams.resid_pdrop,
                    "use_rotary_embeddings": True,
                    "attention": {
                        "name": self.hparams.attention,
                        "dropout": self.hparams.attn_pdrop,
                        "causal": True,
                        "seq_len": self.hparams.block_size,
                        "num_rules": self.hparams.n_head,
                    },
                },
                "feedforward_config": {
                    "name": "MLP",
                    "dropout": self.hparams.mlp_pdrop,
                    "activation": "gelu",
                    "hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
                },
            }
        ]

        config = xFormerConfig(xformer_config)
        config.weight_init = "small"
        self.model = xFormer.from_config(config)

        # decoder head
        self.ln_f = nn.LayerNorm(self.hparams.n_embd)
        self.head = nn.Linear(self.hparams.n_embd, self.hparams.vocab_size, bias=False)

        self.block_size = self.hparams.block_size
        self.apply(self._init_weights)

        # Tokens consumed so far; drives the warmup/cosine LR schedule.
        self._tokens_seen = 0

    def _init_weights(self, module):
        # GPT-style init: small normal weights, zeroed biases, unit LayerNorm.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reset the token counter
        self._tokens_seen = 0

    def get_block_size(self):
        """Return the maximum context length the model was built for."""
        return self.block_size

    def configure_optimizers(self):
        # Create the optimizer and the training schedule:
        # - Handle the per-param weight decay
        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [
            p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)
        ]
        params_nodecay = [
            p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)
        ]
        optim_groups = [
            {"params": params_decay, "weight_decay": self.hparams.weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]

        # - Start with a warm up, ramp up then cosine
        optimizer = torch.optim.AdamW(
            optim_groups, lr=self.hparams.learning_rate, betas=self.hparams.betas
        )

        def update_lr(*_):
            config = self.hparams

            if self._tokens_seen < config.warmup_tokens:
                # linear warmup
                lr_mult = float(self._tokens_seen) / float(max(1, config.warmup_tokens))
                lr_mult = max(lr_mult, 1e-2)  # could be that we've not seen any yet
            else:
                # cosine learning rate decay
                progress = float(self._tokens_seen - config.warmup_tokens) / float(
                    max(1, config.final_tokens - config.warmup_tokens)
                )
                lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

            return lr_mult

        lr_scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                # One lambda per param group (decay / no-decay), same schedule.
                lr_lambda=[update_lr, update_lr],
            ),
            "name": "learning_rate",
            "interval": "step",  # The unit of the scheduler's step size
            "frequency": 1,  # The frequency of the scheduler
        }
        return [optimizer], [lr_scheduler]

    def forward(self, src):
        # predict the next tokens (in latent space)
        prediction = self.model(src)

        # translate the predictions into tokens
        prediction = self.ln_f(prediction)
        logits = self.head(prediction)

        return logits

    def training_step(self, batch, _):
        src, targets = batch

        # Update the tokens we've seen (tracked for LR scheduling)
        self._tokens_seen += (src >= 0).numel()

        # same action as inference
        logits = self(src)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

            # BUG FIX: the original used ``step=trainer.global_step`` where
            # ``trainer`` is an undefined free variable (NameError at runtime);
            # the trainer is reachable from the module as ``self.trainer``.
            # Logging is also guarded by the target check so ``loss.mean()``
            # is never called on ``None``.
            self.logger.log_metrics(
                {
                    "train_loss": loss.mean(),
                    "learning_rate": self.lr_schedulers().get_last_lr()[0],
                },
                step=self.trainer.global_step,
            )

        return loss
class CharDataset(Dataset):
    """Character-level language-modelling dataset.

    Maps every character of ``data`` to an integer id and serves
    (input, target) windows of ``block_size`` token ids, where the target is
    the input shifted one position to the right.
    """

    def __init__(self, data, block_size):
        # BUG FIX: the original used ``list(set(data))``, which enumerates in
        # hash order and so differs between interpreter runs (hash
        # randomization) — a saved model could not be reused against a rebuilt
        # vocabulary. Sorting makes the char <-> id mapping deterministic.
        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        rank_zero_info("data has %d characters, %d unique." % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}  # char -> id
        self.itos = {i: ch for i, ch in enumerate(chars)}  # id -> char
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        # Number of start offsets that still leave a full (block_size + 1) window.
        return len(self.data) - self.block_size

    def __getitem__(self, i):
        chunk = self.data[i : i + self.block_size + 1]
        dix = [self.stoi[s] for s in chunk]

        # src and target are off by one, we want the model to predict the next word
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

    def to_tokens(self, message, device):
        """Encode ``message`` into a (1, len(message)) long tensor on ``device``."""
        return torch.tensor([self.stoi[s] for s in message], dtype=torch.long)[
            None, ...
        ].to(device)

    def from_tokens(self, tokens):
        """Decode a sequence of token ids back into a string."""
        return "".join([self.itos[int(i)] for i in tokens])
# BUG FIX: generation previously ran with autograd enabled, retaining a graph
# for every forward pass — pure waste of memory/compute during inference.
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.

    Args:
        model: module exposing ``get_block_size()`` and returning per-position logits.
        x: (b, t) long tensor of conditioning token indices.
        steps: number of new tokens to generate.
        temperature: softmax temperature (> 0); higher flattens the distribution.
        sample: draw from the distribution when True, otherwise take the argmax.
        top_k: if set, restrict the choice to the k most likely tokens.

    Returns:
        1D tensor: conditioning tokens of the first batch entry followed by the
        generated ones.
    """
    block_size = model.get_block_size()
    model.eval()

    # CREDITS: https://github.com/karpathy/minGPT/blob/master/mingpt/utils.py
    def top_k_logits(logits, k):
        v, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[:, [-1]]] = -float("Inf")
        return out

    for _ in range(steps):
        x_cond = (
            x if x.size(1) <= block_size else x[:, -block_size:]
        )  # crop context if needed

        logits = model(x_cond)

        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature

        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)

        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)

        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)

        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x[0]  # escape the batch dimension
if __name__ == "__main__":
    seed_everything(42)

    # Adjust batch depending on the available memory on your machine.
    # You can also use reversible layers to save memory
    REF_BATCH = 512
    BATCH = 128
    WORKERS = 4
    EPOCHS = 1
    BLOCK = 128
    WARMUP = 20

    if not os.path.exists("input.txt"):
        # Fetch the tiny-shakespeare corpus with the stdlib instead of
        # shelling out to wget, which may not be installed on the host.
        import urllib.request

        urllib.request.urlretrieve(
            "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
            "input.txt",
        )

    # BUG FIX: the original leaked the file handle (open() without close()).
    with open("input.txt", "r") as f:
        text = f.read()

    train_dataset = CharDataset(
        text, BLOCK
    )  # one line of poem is roughly 50 characters

    random_sampler = RandomSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        sampler=random_sampler,
        batch_size=BATCH,
        num_workers=WORKERS,
        pin_memory=True,
    )

    model = GPT(
        vocab_size=train_dataset.vocab_size,
        block_size=train_dataset.block_size,
        attention="scaled_dot_product",
        warmup_tokens=REF_BATCH * WARMUP,
        final_tokens=EPOCHS * len(train_dataset) * BLOCK,
    )
    print(model)

    trainer = Trainer(
        # BUG FIX: ``gpusdevices=1`` is not a Trainer argument (TypeError at
        # construction); the correct keyword is ``devices`` paired with
        # ``accelerator``.
        devices=1,
        accelerator="gpu",
        max_epochs=EPOCHS,
        precision=16,
        log_every_n_steps=1,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )

    trainer.fit(model, train_loader)

    # Sample from the model, let it predict a paragraph
    context = "Friends of my soul"  # prime with something
    x = train_dataset.to_tokens(context, model.device)
    y = sample(model, x, steps=1000, temperature=1.0, sample=True, top_k=10)

    print(train_dataset.from_tokens(y))