# TeLLAgent/tool/comget/model.py
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
class GPTConfig:
""" base GPT config, params common to all GPT versions """
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
def __init__(self, vocab_size, block_size, **kwargs):
self.vocab_size = vocab_size
self.block_size = block_size
for k,v in kwargs.items():
setattr(self, k, v)
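# GPTConfig stores any extra keyword arguments as attributes; e.g. (illustrative call, values assumed)
#   GPTConfig(vocab_size=94, block_size=128, n_layer=8, n_head=8, n_embd=256,
#             num_props=0, scaffold=False, scaffold_maxlen=0, lstm=False, lstm_layers=0)
# yields a config object whose attributes the modules below read directly.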
class GPT1Config(GPTConfig):
""" GPT-1 like network roughly 125M params """
n_layer = 12
n_head = 12
n_embd = 768
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization.
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
super().__init__()
self.scale = nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE: the original RMSNorm paper implementation is not equivalent
# norm_x = x.norm(2, dim=self.dim, keepdim=True)
# rms_x = norm_x * d_x ** (-1. / 2)
# x_normed = x / (rms_x + self.eps)
# keep RMSNorm in float32
norm_x = x.to(torch.float32).pow(2).mean(dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return (self.scale * x_normed).type_as(x)
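# In one line: RMSNorm computes y = scale * x / sqrt(mean(x**2) + eps); unlike LayerNorm
# it does not subtract the mean and has no bias term.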
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
It is possible to use torch.nn.MultiheadAttention here but I am including an
explicit implementation here to show that there is nothing too scary here.
"""
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        # NOTE: self.key / self.query / self.value (and self.proj below) are not used by
        # forward(), which relies on q_proj / kv_proj / c_proj instead; they are left in
        # place so the parameter layout of existing checkpoints is unchanged.
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # query projection (used in forward)
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # fused key/value projection, split in forward
        self.kv_proj = nn.Linear(config.n_embd, 2 * config.n_embd, bias=False)
        # output projection (used in forward)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # second output projection (unused by forward, see NOTE above)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence;
        # it is enlarged by `num` extra positions to cover the prepended conditioning tokens
        # (one property token if num_props is set, plus up to scaffold_maxlen scaffold tokens)
        num = int(bool(config.num_props)) + int(config.scaffold_maxlen)
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + num, config.block_size + num))
                                     .view(1, 1, config.block_size + num, config.block_size + num))
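        # For example (toy size assumed), with config.block_size + num == 4 the buffer holds
        #   [[1, 0, 0, 0],
        #    [1, 1, 0, 0],
        #    [1, 1, 1, 0],
        #    [1, 1, 1, 1]]
        # viewed as (1, 1, 4, 4): position i can only attend to positions j <= i.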
self.n_head = config.n_head
self.n_embd = config.n_embd
    def forward(self, x, layer_past=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        q = self.q_proj(x)
        k, v = self.kv_proj(x).split(self.n_embd, dim=2)  # fused projection, then split into key and value
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
# y = F.scaled_dot_product_attention(
# q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
# )
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        attn_save = att  # keep the pre-dropout attention weights so callers can inspect the maps
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection (NOTE: resid_drop is defined in __init__ but not applied here)
        y = self.c_proj(y)
        return y, attn_save
def find_multiple(n, k):
    """Round n up to the nearest multiple of k."""
    if n % k == 0:
        return n
    return n + k - (n % k)
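# e.g. find_multiple(1365, 256) == 1536 and find_multiple(512, 256) == 512 (illustrative
# values); it is used below to round the MLP hidden width up to a multiple of 256.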
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # gated feed-forward sizing: take 2/3 of the expanded width, rounded up to a multiple of 256
        hidden_dim = 4 * config.n_embd * config.n_head
        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)
        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
def forward(self, x):
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x
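# The MLP above is a SwiGLU-style gated feed-forward: c_proj(silu(c_fc1(x)) * c_fc2(x)),
# used here in place of the single-hidden-layer GELU MLP of a classic GPT block.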
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.rms_2 = RMSNorm(config.n_embd)
        # ln1 / ln2 are not used in forward() (the RMSNorms above are used instead);
        # they are left in place so the parameter layout stays unchanged
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
def forward(self, x):
y, attn = self.attn(self.rms_1(x))
x = x + y
x = x + self.mlp(self.rms_2(x))
return x, attn
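# The Block above uses pre-norm residual wiring: x <- x + attn(rms_1(x)), then
# x <- x + mlp(rms_2(x)); each layer's attention map is returned alongside x.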
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, config):
super().__init__()
# input embedding stem
self.config = config
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.type_emb = nn.Embedding(2, config.n_embd)  # segment type: 0 = conditioning token, 1 = sequence token
        if config.num_props:
            self.prop_nn = nn.Linear(config.num_props, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
# decoder head
        self.ln_f = RMSNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.block_size = config.block_size
        if config.lstm:
            self.lstm = nn.LSTM(input_size=config.n_embd, hidden_size=config.n_embd,
                                num_layers=config.lstm_layers, dropout=0.3, bidirectional=False)
self.apply(self._init_weights)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_block_size(self):
return self.block_size
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(
module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(
module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
)
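    # _init_weights above draws Linear and Embedding weights from N(0, std^2) with
    # std = 0.02 / sqrt(2 * n_layer); LayerNorm/RMSNorm scales and Linear biases keep their defaults.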
def configure_optimizers(self, parameters, train_config):
optimizer = torch.optim.AdamW(parameters, lr=train_config.learning_rate, betas=train_config.betas)
return optimizer
    def forward(self, idx, targets=None, prop=None, scaffold=None):
b, t = idx.size()
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
if self.config.num_props:
assert prop.size(-1) == self.config.num_props, "Num_props should be equal to last dim of property vector"
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        type_embeddings = self.type_emb(torch.ones((b, t), dtype=torch.long, device=idx.device))  # type 1 = sequence tokens
x = self.drop(token_embeddings + position_embeddings + type_embeddings)
        if self.config.num_props:
            # prepend a single property-conditioning embedding (segment type 0)
            type_embd = self.type_emb(torch.zeros((b, 1), dtype=torch.long, device=idx.device))
            if prop.ndim == 2:
                p = self.prop_nn(prop.unsqueeze(1))  # single property vector: (b, num_props) -> (b, 1, n_embd)
            else:
                p = self.prop_nn(prop)  # already token-shaped multi-property input
            p += type_embd
            x = torch.cat([p, x], 1)
        if self.config.scaffold:
            # prepend scaffold-conditioning embeddings (segment type 0)
            type_embd = self.type_emb(torch.zeros((b, 1), dtype=torch.long, device=idx.device))
            scaffold_embeds = self.tok_emb(scaffold)  # (b, scaffold length, n_embd)
            if self.config.lstm:
                # condense the scaffold tokens into the LSTM's final hidden states
                scaffold_embeds = self.lstm(scaffold_embeds.permute(1, 0, 2))[1][0]
                # scaffold_embeds = scaffold_embeds.reshape(scaffold_embeds.shape[1], scaffold_embeds.shape[0], 2, self.config.n_embd).mean(2)
                scaffold_embeds = scaffold_embeds.permute(1, 0, 2)
                # scaffold_embeds = scaffold_embeds.reshape(self.config.lstm_layers, 1, -1, self.config.n_embd)[-1].permute(1,0,2)
                # scaffold_embeds = scaffold_embeds.reshape(scaffold_embeds.shape[1], scaffold_embeds.shape[0], self.config.n_embd)
            scaffold_embeds += type_embd
            x = torch.cat([scaffold_embeds, x], 1)
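        # layout of x along dim 1 at this point:
        # [scaffold tokens (if any)][property token (if any)][input sequence tokens]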
# x = self.blocks(x)
attn_maps = []
for layer in self.blocks:
x, attn = layer(x)
attn_maps.append(attn)
x = self.ln_f(x)
logits = self.head(x)
        # drop the logits at the prepended conditioning positions so the output aligns with `idx`
        if self.config.num_props and self.config.scaffold:
num = int(bool(self.config.num_props)) + int(self.config.scaffold_maxlen)
elif self.config.num_props:
num = int(bool(self.config.num_props))
elif self.config.scaffold:
num = int(self.config.scaffold_maxlen)
else:
num = 0
logits = logits[:, num:, :]
# if self.config.num_props or self.config.scaffold:
# num = int(bool(self.config.num_props)) + int(self.config.scaffold_maxlen) #int(self.config.lstm_layers) # int(self.config.scaffold) # int(self.config.scaffold)
# print(logits.shape)
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1))
        return logits, loss, attn_maps  # attn_maps: list of n_layer tensors, each (batch_size, num_heads, T, T)
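

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only). The hyperparameters and the
# unconditioned setting below are assumptions for demonstration, not the values
# used by the surrounding project; they simply exercise GPTConfig/GPT end to end.
if __name__ == "__main__":
    cfg = GPTConfig(
        vocab_size=100, block_size=64,
        n_layer=2, n_head=4, n_embd=128,
        num_props=0, scaffold=False, scaffold_maxlen=0,
        lstm=False, lstm_layers=0,
    )
    model = GPT(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, cfg.block_size))
    logits, loss, attn_maps = model(idx)
    print(logits.shape)  # expected: torch.Size([2, 64, 100])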