gunasekar
commited on
Commit
·
3df5337
1
Parent(s):
792a323
Upload model.py
Browse files- gpt2/model.py +148 -0
gpt2/model.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import math
|
| 5 |
+
class GPT2Config:
|
| 6 |
+
def __init__(self):
|
| 7 |
+
self.vocab_size = 50257
|
| 8 |
+
self.n_positions = 1024
|
| 9 |
+
self.n_ctx = 1024
|
| 10 |
+
self.n_embd = 768
|
| 11 |
+
self.n_layer = 12
|
| 12 |
+
self.n_head = 12
|
| 13 |
+
self.resid_pdrop = 0.1
|
| 14 |
+
self.embd_pdrop = 0.1
|
| 15 |
+
self.attn_pdrop = 0.1
|
| 16 |
+
self.layer_norm_epsilon = 1e-5
|
| 17 |
+
self.initializer_range = 0.
|
| 18 |
+
self.output_attentions = False
|
| 19 |
+
class GPT2Model(nn.Module):
|
| 20 |
+
def __init__(self, config):
|
| 21 |
+
super(GPT2Model, self).__init__()
|
| 22 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 23 |
+
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
| 24 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 25 |
+
self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
|
| 26 |
+
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
| 27 |
+
def forward(self, input_ids, position_ids=None, past=None):
|
| 28 |
+
if past is None:
|
| 29 |
+
past_length = 0
|
| 30 |
+
past = [None] * len(self.h)
|
| 31 |
+
else:
|
| 32 |
+
past_length = past[0][0].size(-2)
|
| 33 |
+
if position_ids is None:
|
| 34 |
+
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
|
| 35 |
+
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
| 36 |
+
input_shape = input_ids.size()
|
| 37 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 38 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
| 39 |
+
inputs_embeds = self.wte(input_ids)
|
| 40 |
+
position_embeds = self.wpe(position_ids)
|
| 41 |
+
if inputs_embeds.size() != position_embeds.size():
|
| 42 |
+
raise ValueError("the embeddings of inputs and position are not the same size")
|
| 43 |
+
hidden_states = inputs_embeds + position_embeds
|
| 44 |
+
hidden_states = self.drop(hidden_states)
|
| 45 |
+
presents = ()
|
| 46 |
+
for block, past in zip(self.h, past):
|
| 47 |
+
hidden_states, present = block(hidden_states, past=past)
|
| 48 |
+
presents = presents + (present,)
|
| 49 |
+
hidden_states = self.ln_f(hidden_states)
|
| 50 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
| 51 |
+
return hidden_states.view(*output_shape), presents
|
| 52 |
+
class GPT2Block(nn.Module):
|
| 53 |
+
def __init__(self, config):
|
| 54 |
+
super(GPT2Block, self).__init__()
|
| 55 |
+
nx = config.n_embd
|
| 56 |
+
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
| 57 |
+
self.attn = GPT2Attention(config)
|
| 58 |
+
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
| 59 |
+
self.mlp = GPT2MLP(config)
|
| 60 |
+
def forward(self, x, past):
|
| 61 |
+
a, present = self.attn(self.ln_1(x), past=past)
|
| 62 |
+
x = x + a
|
| 63 |
+
m = self.mlp(self.ln_2(x))
|
| 64 |
+
x = x + m
|
| 65 |
+
return x, present
|
| 66 |
+
class GPT2Attention(nn.Module):
|
| 67 |
+
def __init__(self, config):
|
| 68 |
+
super(GPT2Attention, self).__init__()
|
| 69 |
+
self.output_attentions = config.output_attentions
|
| 70 |
+
self.n_head = config.n_head
|
| 71 |
+
self.split_size = config.n_embd
|
| 72 |
+
self.scale = self.split_size ** -0.5
|
| 73 |
+
self.c_attn = Conv1D(3 * self.split_size, self.split_size)
|
| 74 |
+
self.c_proj = Conv1D(self.split_size, self.split_size)
|
| 75 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
| 76 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 77 |
+
self.bias = nn.Parameter(torch.zeros(1, 1, 1024, 1024))
|
| 78 |
+
def _attn(self, q, k, v):
|
| 79 |
+
w = torch.matmul(q, k)
|
| 80 |
+
if self.scale:
|
| 81 |
+
w = w / math.sqrt(v.size(-1))
|
| 82 |
+
w = w.softmax(dim=-1)
|
| 83 |
+
w = self.attn_dropout(w)
|
| 84 |
+
return torch.matmul(w, v)
|
| 85 |
+
def merge_heads(self, x):
|
| 86 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 87 |
+
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
| 88 |
+
return x.view(*new_x_shape)
|
| 89 |
+
def split_heads(self, x, k=False):
|
| 90 |
+
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
| 91 |
+
x = x.view(*new_x_shape)
|
| 92 |
+
if k:
|
| 93 |
+
return x.permute(0, 2, 3, 1)
|
| 94 |
+
else:
|
| 95 |
+
return x.permute(0, 2, 1, 3)
|
| 96 |
+
def forward(self, x, past):
|
| 97 |
+
x = self.c_attn(x)
|
| 98 |
+
query, key, value = x.split(self.split_size, dim=2)
|
| 99 |
+
query = self.split_heads(query)
|
| 100 |
+
key = self.split_heads(key, k=True)
|
| 101 |
+
value = self.split_heads(value)
|
| 102 |
+
if past is not None:
|
| 103 |
+
past_key, past_value = past
|
| 104 |
+
key = torch.cat((past_key, key), dim=-1)
|
| 105 |
+
value = torch.cat((past_value, value), dim=-1)
|
| 106 |
+
present = (key, value)
|
| 107 |
+
attn_output = self._attn(query, key, value)
|
| 108 |
+
attn_output = self.merge_heads(attn_output)
|
| 109 |
+
attn_output = self.c_proj(attn_output)
|
| 110 |
+
# Apply bias
|
| 111 |
+
attn_output += self.bias
|
| 112 |
+
attn_output = self.resid_dropout(attn_output)
|
| 113 |
+
return attn_output, present
|
| 114 |
+
class GPT2MLP(nn.Module):
|
| 115 |
+
def __init__(self, config):
|
| 116 |
+
super(GPT2MLP, self).__init__()
|
| 117 |
+
self.c_fc = Conv1D(4 * config.n_embd, config.n_embd)
|
| 118 |
+
self.c_proj = Conv1D(config.n_embd, 4 * config.n_embd)
|
| 119 |
+
self.act = F.gelu
|
| 120 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 121 |
+
def forward(self, x):
|
| 122 |
+
h = self.act(self.c_fc(x))
|
| 123 |
+
h2 = self.c_proj(h)
|
| 124 |
+
return self.dropout(h2)
|
| 125 |
+
class Conv1D(nn.Module):
|
| 126 |
+
def __init__(self, nf, nx):
|
| 127 |
+
super(Conv1D, self).__init__()
|
| 128 |
+
self.nf = nf
|
| 129 |
+
w = torch.empty(nx, nf)
|
| 130 |
+
nn.init.normal_(w, std=0.02)
|
| 131 |
+
self.weight = nn.Parameter(w)
|
| 132 |
+
self.bias = nn.Parameter(torch.zeros(nf))
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
size_out = x.size()[:-1] + (self.nf,)
|
| 136 |
+
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
| 137 |
+
x = x.view(*size_out)
|
| 138 |
+
return x
|
| 139 |
+
|
| 140 |
+
config = GPT2Config()
|
| 141 |
+
|
| 142 |
+
model = GPT2Model(config)
|
| 143 |
+
|
| 144 |
+
state_dict = torch.load(r'c:\tmp\SI\gpt2\pytorch_model.bin', map_location=torch.device('cpu'))
|
| 145 |
+
|
| 146 |
+
model.load_state_dict(state_dict)
|
| 147 |
+
|
| 148 |
+
print(model)
|