# -*- coding: utf-8 -*-
# Yan Chen 2023.10
# yanchen@xjtu.edu.com
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer block is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""
import math
import json

import torch
import torch.nn as nn
from torch.nn import functional as F

class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)


class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768
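
# Illustrative construction (a sketch; the hyperparameter values are made up).
# Note that GPT and CausalSelfAttention below also read `num_props` and `lstm`
# off the config, so those must be passed through the kwargs:
#
#   config = GPTConfig(vocab_size=100, block_size=64,
#                      n_layer=4, n_head=4, n_embd=128,
#                      num_props=0, lstm=False)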

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence;
        # the mask is one slot larger when a property token is prepended (see GPT.forward below)
        num = int(bool(config.num_props))
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + num, config.block_size + num))
                                          .view(1, 1, config.block_size + num, config.block_size + num))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # causal self-attention; self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        attn_save = att  # keep the pre-dropout attention weights to return to the caller
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_drop(self.proj(y))
        return y, attn_save
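
# A minimal smoke test for the attention layer (a sketch, not part of the model;
# all sizes below are made up for illustration). Call it manually if desired.
def _demo_attention():
    cfg = GPTConfig(vocab_size=10, block_size=8, n_head=2, n_embd=16,
                    num_props=0, lstm=False)
    attn = CausalSelfAttention(cfg)
    y, att = attn(torch.randn(1, 8, 16))
    assert y.shape == (1, 8, 16)      # the layer preserves the input shape
    assert att.shape == (1, 2, 8, 8)  # one (T, T) attention map per head
    # causality: the first position can only attend to itself
    assert torch.allclose(att[0, 0, 0], torch.tensor([1.0] + [0.0] * 7))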

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        y, attn = self.attn(self.ln1(x))
        x = x + y
        x = x + self.mlp(self.ln2(x))
        return x, attn
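
# Each block adds two residual branches on top of its input, so the shape is
# preserved; a sketch with made-up sizes:
#
#   block = Block(GPTConfig(vocab_size=10, block_size=8, n_head=2, n_embd=16,
#                           num_props=0, lstm=False))
#   out, attn = block(torch.randn(1, 8, 16))  # out.shape == (1, 8, 16)
#
# Because forward returns an (x, attn) tuple, GPT.forward below iterates over
# the blocks one by one instead of calling the nn.Sequential container directly.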

class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()
        # print(json.dumps(config.__dict__, indent=2))
        # input embedding stem
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.type_emb = nn.Embedding(2, config.n_embd)
        if config.num_props:
            self.prop_nn = nn.Linear(config.num_props, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        if config.lstm:
            self.lstm = nn.LSTM(input_size=config.n_embd, hidden_size=config.n_embd,
                                num_layers=config.lstm_layers, dropout=0.3, bidirectional=False)
        self.apply(self._init_weights)
        # logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        we are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.LSTM)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                if 'bias' in pn:
                    # no bias will be decayed (this also catches the LSTM's bias_ih_l* / bias_hh_l* parameters)
                    no_decay.add(fpn)
                elif 'weight' in pn and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)
        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')
        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert len(param_dict.keys() - union_params) == 0, \
            "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer
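
    # Illustrative use (a sketch: only the attribute names weight_decay,
    # learning_rate and betas are required by the code above; the values here
    # are made-up placeholders):
    #
    #   from types import SimpleNamespace
    #   train_config = SimpleNamespace(weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95))
    #   optimizer = model.configure_optimizers(train_config)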

    def forward(self, idx, targets=None, prop=None):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        if self.config.num_props:
            assert prop.size(-1) == self.config.num_props, "num_props must equal the last dim of the property vector"
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        type_embeddings = self.type_emb(torch.ones((b, t), dtype=torch.long, device=idx.device))
        x = self.drop(token_embeddings + position_embeddings + type_embeddings)
        embed = x
        if self.config.num_props:
            # prepend a single property token (type 0) so generation can be conditioned on the properties
            type_embd = self.type_emb(torch.zeros((b, 1), dtype=torch.long, device=idx.device))
            if prop.ndim == 2:
                p = self.prop_nn(prop.unsqueeze(1))  # single property
            else:
                p = self.prop_nn(prop)  # multiple properties
            p += type_embd
            x = torch.cat([p, x], 1)
        # run the blocks one at a time (rather than self.blocks(x)) so the attention maps can be collected
        attn_maps = []
        for layer in self.blocks:
            x, attn = layer(x)
            attn_maps.append(attn)
        x = self.ln_f(x)
        logits = self.head(x)
        # drop the logits for the prepended property token, if any
        num = int(bool(self.config.num_props))
        logits = logits[:, num:, :]
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1))
        return logits, loss, attn_maps, embed  # attn_maps: num_layers x (batch_size, num_heads, seq_len, seq_len)
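
    # Illustrative forward pass (a sketch with made-up sizes; a model built with
    # num_props=0 needs no property vector):
    #
    #   idx = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq)
    #   logits, loss, attn_maps, embed = model(idx, targets=idx)
    #   # logits: (2, 16, vocab_size); attn_maps: list of n_layer attention tensors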

    def sample(self, x, steps, temperature=1.0, do_sample=False, top_k=None, top_p=None, prop=None):
        """
        Take a conditioning sequence of indices in x (of shape (b, t)) and predict the next token in
        the sequence, feeding the predictions back into the model each time. Clearly the sampling
        has quadratic complexity, unlike an RNN, which is linear, and a finite context window of
        block_size, unlike an RNN, which has an infinite context window.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
            """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
                Args:
                    logits: logits distribution of shape (batch size, vocabulary size)
                    top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
                    top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                        Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
                From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
            """
            top_k = min(top_k, logits.size(-1))  # safety check
            if top_k > 0:
                # remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = filter_value
            if top_p > 0.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # shift the indices to the right to also keep the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                # scatter the sorted mask back to the original indexing
                indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
                logits[indices_to_remove] = filter_value
            return logits
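
        # e.g. top_k=3 keeps only the three largest logits, while top_p=0.9 keeps
        # the smallest set of top tokens whose probabilities sum to at least 0.9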

        for k in range(steps):
            x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:]  # crop context if needed
            # forward the model to get the logits for the index in the sequence
            logits, _, _, _ = self(x_cond, prop=prop)  # for sampling, no target
            # pluck the logits at the final step and scale by the desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options and/or apply nucleus (top-p) filtering;
            # map the None defaults to the "disabled" values the filter expects
            logits = top_k_top_p_filtering(logits,
                                           top_k=top_k if top_k is not None else 0,
                                           top_p=top_p if top_p is not None else 0.0)
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely token
            if do_sample:
                x_next = torch.multinomial(probs, num_samples=1)
            else:
                _, x_next = torch.topk(probs, k=1, dim=-1)
            # append sampled index to the running sequence and continue
            x = torch.cat((x, x_next), dim=1)
        return x[:, 1:]  # return the sequence without its first (context) token
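
# A minimal end-to-end smoke test (a sketch: every hyperparameter value below is
# made up for illustration, and num_props=0 / lstm=False simply disable the
# optional property-conditioning and LSTM branches the model reads off the config).
if __name__ == "__main__":
    config = GPTConfig(vocab_size=100, block_size=32,
                       n_layer=2, n_head=2, n_embd=64,
                       num_props=0, lstm=False)
    model = GPT(config)
    idx = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq)
    logits, loss, attn_maps, embed = model(idx, targets=idx)
    print(logits.shape, loss.item(), len(attn_maps), embed.shape)
    model.eval()
    out = model.sample(idx[:, :1], steps=8, temperature=1.0,
                       do_sample=True, top_k=10, top_p=0.9)
    print(out.shape)  # (2, 8): the leading context token is stripped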