# Source: Hub / gpt2 / model.py
# Uploaded by: gunasekar
# Commit message: "Upload model.py"
# Commit: 3df5337
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
class GPT2Config:
    """Hyper-parameters for the GPT-2 model defined in this file.

    Defaults describe 12 layers of width 768 with 12 attention heads, a
    50257-token vocabulary and 1024 positions. Any default can be overridden
    by keyword, e.g. ``GPT2Config(n_layer=6)``.

    Raises:
        TypeError: if a keyword does not name an existing setting.
    """

    def __init__(self, **kwargs):
        self.vocab_size = 50257
        self.n_positions = 1024
        self.n_ctx = 1024
        self.n_embd = 768
        self.n_layer = 12
        self.n_head = 12
        self.resid_pdrop = 0.1
        self.embd_pdrop = 0.1
        self.attn_pdrop = 0.1
        self.layer_norm_epsilon = 1e-5
        # Std of the normal weight init. The previous value "0." looked
        # truncated; 0.02 matches the std Conv1D hard-codes for its weights.
        self.initializer_range = 0.02
        self.output_attentions = False
        # Optional overrides; only names defined above are accepted so that a
        # typo fails loudly instead of being silently ignored.
        for name, value in kwargs.items():
            if not hasattr(self, name):
                raise TypeError(f"GPT2Config got an unexpected keyword argument {name!r}")
            setattr(self, name, value)
class GPT2Model(nn.Module):
    """GPT-2 transformer stack without a language-model head.

    Token embeddings (``wte``) and position embeddings (``wpe``) are summed,
    passed through dropout, ``config.n_layer`` GPT2Block layers and a final
    LayerNorm (``ln_f``).
    """

    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, position_ids=None, past=None):
        """Run the transformer over ``input_ids``.

        Args:
            input_ids: integer token ids; the last dimension is the sequence.
            position_ids: optional position ids shaped like ``input_ids``.
                Defaults to positions continuing right after the cached ones.
            past: optional per-layer cache, one ``(key, value)`` pair per
                block, as returned in ``presents`` by a previous call.

        Returns:
            ``(hidden_states, presents)``. ``hidden_states`` has shape
            ``input_ids.shape + (hidden_size,)``. ``presents`` is a tuple with
            each block's returned cache entry.

        Raises:
            ValueError: if token and position embeddings differ in shape.
        """
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Cached values are (batch, head, seq, head_dim), so dim -2 is the
            # number of cached positions. Cached keys are stored transposed,
            # (batch, head, head_dim, seq), so their dim -2 is head_dim and
            # must not be used here.
            past_length = past[0][1].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_ids.size(-1) + past_length,
                dtype=torch.long,
                device=input_ids.device,
            )
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        input_shape = input_ids.size()
        # Flatten any leading dims into one batch dim for the embeddings.
        input_ids = input_ids.view(-1, input_shape[-1])
        position_ids = position_ids.view(-1, input_shape[-1])
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if inputs_embeds.size() != position_embeds.size():
            raise ValueError("the embeddings of inputs and position are not the same size")
        hidden_states = self.drop(inputs_embeds + position_embeds)
        presents = ()
        # Distinct loop name so the ``past`` argument is not shadowed.
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, past=layer_past)
            presents = presents + (present,)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents
class GPT2Block(nn.Module):
    """One pre-LayerNorm transformer layer.

    Computes ``h = x + attn(ln_1(x))`` followed by ``h + mlp(ln_2(h))`` and
    also returns the attention cache entry produced for this layer.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        eps = config.layer_norm_epsilon
        # Registration order (ln_1, attn, ln_2, mlp) is kept as before.
        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = GPT2MLP(config)

    def forward(self, x, past):
        attn_out, present = self.attn(self.ln_1(x), past=past)
        hidden = x + attn_out
        ffn_out = self.mlp(self.ln_2(hidden))
        return hidden + ffn_out, present
class GPT2Attention(nn.Module):
    """Causal multi-head self-attention of a GPT-2 block.

    The input (assumed ``(batch, seq, n_embd)``, as produced by GPT2Model) is
    projected to queries, keys and values with a single ``c_attn`` and split
    into ``n_head`` heads. It is attended under a lower-triangular (causal)
    mask, merged back and projected with ``c_proj``.

    Cache layout: keys are kept as ``(batch, head, head_dim, seq)``
    (pre-transposed for ``q @ k``). Values are ``(batch, head, seq, head_dim)``.
    """

    def __init__(self, config):
        super(GPT2Attention, self).__init__()
        self.output_attentions = config.output_attentions
        self.n_head = config.n_head
        self.split_size = config.n_embd
        # Flag only: when truthy, _attn divides logits by sqrt(head_dim).
        self.scale = True
        self.c_attn = Conv1D(3 * self.split_size, self.split_size)
        self.c_proj = Conv1D(self.split_size, self.split_size)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # Causal mask (1 = may attend, 0 = masked). It is a persistent buffer,
        # not a trainable Parameter: the state-dict key "...attn.bias" still
        # loads, but the mask is never updated by an optimizer.
        n_ctx = config.n_ctx
        self.register_buffer(
            "bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)
        )

    def _attn(self, q, k, v):
        """Return ``softmax(mask(q @ k / sqrt(head_dim))) @ v`` per head."""
        w = torch.matmul(q, k)  # (batch, head, n_query, n_key)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        # The queries are the last nd of the ns key positions (earlier ones
        # come from the cache), so use the matching rows of the causal mask.
        mask = self.bias[:, :, ns - nd:ns, :ns]
        w = w.masked_fill(mask == 0, torch.finfo(w.dtype).min)
        w = w.softmax(dim=-1)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        """(batch, head, seq, head_dim) -> (batch, seq, head * head_dim)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        """(batch, seq, width) -> (batch, head, seq, head_dim).

        For keys (``k=True``) the result is ``(batch, head, head_dim, seq)``.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, past=None):
        """Attend over ``x`` (plus cached ``past``) and return ``(output, present)``.

        ``past`` is an optional ``(key, value)`` pair in the cache layout
        described on the class. ``present`` is the updated pair to feed back
        on the next call.
        """
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if past is not None:
            past_key, past_value = past
            # Sequence is the last dim of keys but dim -2 of values.
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value)
        attn_output = self._attn(query, key, value)
        attn_output = self.merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        return attn_output, present
class GPT2MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a GPT-2 block.

    Expands ``n_embd`` features to ``4 * n_embd`` and applies GELU. It then
    projects back to ``n_embd`` and applies residual dropout.
    """

    def __init__(self, config):
        super(GPT2MLP, self).__init__()
        self.c_fc = Conv1D(4 * config.n_embd, config.n_embd)
        self.c_proj = Conv1D(config.n_embd, 4 * config.n_embd)
        # GPT-2 uses the tanh approximation of GELU. F.gelu's default (exact
        # erf form) gives slightly different activations on pretrained weights.
        self.act = self._gelu_tanh
        self.dropout = nn.Dropout(config.resid_pdrop)

    @staticmethod
    def _gelu_tanh(x):
        """Tanh-approximated GELU: 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3)))."""
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)
class Conv1D(nn.Module):
    """Affine map ``y = x @ weight + bias`` over the last dimension.

    Despite the name there is no convolution: ``weight`` is stored as
    ``(nx, nf)``, i.e. input features first (transposed relative to
    ``nn.Linear``). ``weight`` is initialised from N(0, 0.02**2) and ``bias``
    to zeros.

    Args:
        nf: number of output features.
        nx: number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        weight = torch.empty(nx, nf)
        nn.init.normal_(weight, std=0.02)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        leading = x.size()[:-1]
        # Collapse all leading dims so one addmm handles any batch shape.
        flat = x.view(-1, x.size(-1))
        projected = torch.addmm(self.bias, flat, self.weight)
        return projected.view(*(leading + (self.nf,)))
# Script section: build the model, load a local checkpoint and print the
# module tree. NOTE: this runs at import time, so importing this file requires
# the checkpoint to exist at the hard-coded path below.
config = GPT2Config()
model = GPT2Model(config)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only point this at checkpoints from a trusted source.
state_dict = torch.load(
    r'c:\tmp\SI\gpt2\pytorch_model.bin',
    map_location=torch.device('cpu'),
)
model.load_state_dict(state_dict)
print(model)