|
|
""" |
|
|
GPT model: |
|
|
- the initial stem consists of a combination of token encoding and a positional encoding |
|
|
- the meat of it is a uniform sequence of Transformer blocks |
|
|
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block |
|
|
- all blocks feed into a central residual pathway similar to resnets |
|
|
- the final decoder is a linear projection into a vanilla Softmax classifier |
|
|
""" |
|
|
|
|
|
import math |
|
|
import logging |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
# Module-level logger; GPT.__init__ uses it to report the parameter count.
logger = logging.getLogger(__name__)
|
|
|
|
|
class GPTConfig:
    """Base GPT configuration: settings shared by every GPT variant.

    The dropout probabilities below are class-level defaults. Any of them,
    and any additional hyper-parameter, can be set per instance by passing
    it as a keyword argument to the constructor.
    """

    # Default dropout probabilities read when the model builds its
    # embedding, residual and attention dropout layers.
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        """Record the required sizes, then attach every extra keyword as an attribute.

        Args:
            vocab_size: stored as ``self.vocab_size``.
            block_size: stored as ``self.block_size``.
            **kwargs: each ``name=value`` pair becomes ``self.name = value``,
                shadowing a class-level default of the same name.
        """
        self.vocab_size = vocab_size
        self.block_size = block_size
        for name, value in kwargs.items():
            setattr(self, name, value)
|
|
|
|
|
class GPT1Config(GPTConfig):
    """GPT-1 like network (described upstream as roughly 125M params)."""

    # Width of the embeddings, then the number of attention heads and layers.
    n_embd = 768
    n_head = 12
    n_layer = 12
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the input by the reciprocal root-mean-square taken along ``dim``
    and multiplies the result by a learned per-feature ``scale``.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        """Create the layer.

        Args:
            size: length of the learned ``scale`` vector (initialised to ones).
            dim: dimension over which the mean of squares is taken.
            eps: constant added to the mean of squares before ``rsqrt``.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the RMS-normalised, scaled input cast back to ``x``'s dtype."""
        # The mean of squares is computed on a float32 copy of the input.
        mean_square = x.to(torch.float32).pow(2).mean(dim=self.dim, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        return (self.scale * normalised).type_as(x)
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    torch.nn.MultiheadAttention could be used instead; the explicit version is
    kept to show that nothing complicated is going on.

    ``forward`` only uses ``q_proj``, ``kv_proj``, ``c_proj`` and ``attn_drop``.
    The ``key``/``query``/``value``/``proj`` layers and ``resid_drop`` are
    constructed but never called here; they are kept, in their original
    creation order, so parameter names and initialisation stay unchanged.
    """

    def __init__(self, config):
        """Build the projections, dropout layers and the causal mask buffer.

        Reads ``n_embd``, ``n_head``, ``attn_pdrop``, ``resid_pdrop``,
        ``block_size``, ``num_props`` and ``scaffold_maxlen`` from ``config``.
        """
        super().__init__()
        assert config.n_embd % config.n_head == 0

        width = config.n_embd

        # Biased projections (not used by forward()).
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)

        # Bias-free projections used by forward(): queries, fused keys/values, output.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.kv_proj = nn.Linear(width, 2 * width, bias=False)
        self.c_proj = nn.Linear(width, width, bias=False)

        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)

        self.proj = nn.Linear(width, width)

        # Extra mask positions: one if properties are configured, plus scaffold_maxlen.
        num = int(bool(config.num_props)) + int(config.scaffold_maxlen)
        span = config.block_size + num

        # Lower-triangular mask: a position may attend only to itself and earlier ones.
        lower_triangle = torch.tril(torch.ones(span, span))
        self.register_buffer("mask", lower_triangle.view(1, 1, span, span))

        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x, layer_past=None):
        """Apply causal self-attention to ``x`` of size ``(B, T, C)``.

        ``layer_past`` is accepted but not used.

        Returns:
            ``(y, attn_save)``: the projected output of size ``(B, T, C)`` and
            the softmax attention weights (before dropout) of size
            ``(B, n_head, T, T)``.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(tensor):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return tensor.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q = self.q_proj(x)
        k, v = self.kv_proj(x).split(self.n_embd, dim=2)
        k = split_heads(k)
        q = split_heads(q)
        v = split_heads(v)

        # Scaled dot-product scores; disallowed positions are set to -inf before softmax.
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn_save = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge the heads back into (B, T, C).
        y = self.attn_drop(attn_save) @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(y), attn_save
|
|
|
|
|
def find_multiple(n, k):
    """Round ``n`` up to a multiple of ``k``.

    Returns ``n`` unchanged when it is already divisible by ``k``; otherwise
    adds whatever is missing to reach the next multiple.
    """
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Gated feed-forward block computing ``c_proj(silu(c_fc1(x)) * c_fc2(x))``."""

    def __init__(self, config):
        """Create the three bias-free linear layers.

        The hidden width is two thirds of ``4 * n_embd * n_head`` (truncated
        to an int), rounded up to a multiple of 256.
        """
        super().__init__()
        width = config.n_embd
        expanded = 4 * config.n_embd * config.n_head
        n_hidden = find_multiple(int(2 * expanded / 3), 256)

        self.c_fc1 = nn.Linear(width, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(width, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, width, bias=False)

    def forward(self, x):
        """Gate one projection of ``x`` with the SiLU of the other, then project back."""
        gate = F.silu(self.c_fc1(x))
        return self.c_proj(gate * self.c_fc2(x))
|
|
|
|
|
class Block(nn.Module):
    """ an unassuming Transformer block: normalised attention, then a normalised MLP, each added residually """

    def __init__(self, config):
        """Create the normalisation layers, the attention layer and the MLP."""
        super().__init__()
        width = config.n_embd

        self.rms_1 = RMSNorm(width)
        self.rms_2 = RMSNorm(width)
        # ln1/ln2 are not called by forward(); they are kept so the module's
        # parameter names and creation order stay unchanged.
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``(x_out, attn)``: the updated activations and the attention weights."""
        attn_out, attn_weights = self.attn(self.rms_1(x))
        after_attn = x + attn_out
        after_mlp = after_attn + self.mlp(self.rms_2(after_attn))
        return after_mlp, attn_weights
|
|
|
|
|
class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size

    Generation can optionally be conditioned on a property tensor
    (``config.num_props``) and/or a scaffold token sequence
    (``config.scaffold``). Both are embedded and prepended to the token
    sequence before the Transformer blocks run; the corresponding positions
    are removed from the returned logits.
    """

    def __init__(self, config):
        """Build the embeddings, Transformer blocks and output head.

        Reads (among others) ``vocab_size``, ``block_size``, ``n_embd``,
        ``n_layer``, ``embd_pdrop``, ``num_props`` and ``lstm`` from
        ``config``; ``lstm_layers`` is read only when ``config.lstm`` is truthy.
        """
        super().__init__()

        # input embedding stem
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # type 1 marks sequence tokens, type 0 marks conditioning tokens (see forward)
        self.type_emb = nn.Embedding(2, config.n_embd)
        if config.num_props:
            self.prop_nn = nn.Linear(config.num_props, config.n_embd)

        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)

        # transformer; each Block returns (x, attn), so forward() iterates the
        # container manually instead of calling it
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])

        # decoder head
        self.ln_f = RMSNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size

        if config.lstm:
            self.lstm = nn.LSTM(
                input_size=config.n_embd,
                hidden_size=config.n_embd,
                num_layers=config.lstm_layers,
                dropout=0.3,
                bidirectional=False,
            )
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        """Return the maximum token sequence length accepted by ``forward``."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from a scaled normal distribution.

        The standard deviation is ``0.02 / sqrt(2 * n_layer)``. Biases and all
        other module types are left at their defaults.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
            )

    def configure_optimizers(self, parameters, train_config):
        """Return an AdamW optimizer over ``parameters``.

        Uses ``train_config.learning_rate`` and ``train_config.betas``.
        """
        optimizer = torch.optim.AdamW(parameters, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None, prop=None, scaffold=None):
        """Run the model on a batch of token indices.

        Args:
            idx: integer token indices of size ``(b, t)`` with ``t <= block_size``.
            targets: optional target indices; when given, the cross-entropy
                loss against the logits is computed.
            prop: property tensor, required when ``config.num_props`` is
                truthy. Its last dimension must equal ``config.num_props``;
                a 2-D tensor is treated as one conditioning token per sample.
            scaffold: scaffold token indices, required when
                ``config.scaffold`` is truthy.

        Returns:
            ``(logits, loss, attn_maps)``: logits for the ``t`` token positions,
            the loss (``None`` without ``targets``) and a list with the
            attention weights of every block.
        """
        b, t = idx.size()

        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        if self.config.num_props:
            assert prop.size(-1) == self.config.num_props, "Num_props should be equal to last dim of property vector"

        # forward the GPT model
        token_embeddings = self.tok_emb(idx)
        position_embeddings = self.pos_emb[:, :t, :]
        type_embeddings = self.type_emb(torch.ones((b, t), dtype=torch.long, device=idx.device))
        x = self.drop(token_embeddings + position_embeddings + type_embeddings)

        if self.config.num_props:
            type_embd = self.type_emb(torch.zeros((b, 1), dtype=torch.long, device=idx.device))
            if prop.ndim == 2:
                p = self.prop_nn(prop.unsqueeze(1))
            else:
                p = self.prop_nn(prop)
            p = p + type_embd
            x = torch.cat([p, x], 1)

        if self.config.scaffold:
            type_embd = self.type_emb(torch.zeros((b, 1), dtype=torch.long, device=idx.device))

            scaffold_embeds = self.tok_emb(scaffold)
            if self.config.lstm:
                # Encode the scaffold with the LSTM and use its final hidden
                # states (one per LSTM layer) as the conditioning tokens.
                scaffold_embeds = self.lstm(scaffold_embeds.permute(1, 0, 2))[1][0]
                scaffold_embeds = scaffold_embeds.permute(1, 0, 2)

            scaffold_embeds = scaffold_embeds + type_embd
            x = torch.cat([scaffold_embeds, x], 1)

        # Number of conditioning positions actually prepended. Deriving it from
        # the assembled sequence keeps the logits aligned with `idx` even when
        # the prefix length differs from `1 + scaffold_maxlen` (LSTM-encoded
        # scaffold, 3-D property tensor, scaffold shorter than scaffold_maxlen).
        num = x.size(1) - t

        attn_maps = []

        for layer in self.blocks:
            x, attn = layer(x)
            attn_maps.append(attn)

        x = self.ln_f(x)
        logits = self.head(x)

        # drop the positions that belong to the conditioning prefix
        logits = logits[:, num:, :]

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            # reshape (rather than view) also accepts non-contiguous targets
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss, attn_maps