gunasekar committed on
Commit
3df5337
·
1 Parent(s): 792a323

Upload model.py

Browse files
Files changed (1) hide show
  1. gpt2/model.py +148 -0
gpt2/model.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import math
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
5
class GPT2Config:
    """Configuration for the 124M-parameter GPT-2 model ("small" defaults).

    Any default can be overridden by keyword argument, e.g.
    ``GPT2Config(n_layer=24, n_embd=1024)``; unknown names raise
    ``AttributeError`` so typos don't silently create dead fields.
    """

    def __init__(self, **overrides):
        self.vocab_size = 50257       # 50,000 BPE merges + 256 byte tokens + <|endoftext|>
        self.n_positions = 1024       # maximum sequence length (learned position embeddings)
        self.n_ctx = 1024             # context window used by the causal attention mask
        self.n_embd = 768             # hidden size
        self.n_layer = 12             # number of transformer blocks
        self.n_head = 12              # attention heads per block
        self.resid_pdrop = 0.1        # dropout after attention/MLP projections
        self.embd_pdrop = 0.1         # dropout on the summed embeddings
        self.attn_pdrop = 0.1         # dropout on attention weights
        self.layer_norm_epsilon = 1e-5
        # Was 0. in the original -- that would zero-initialize any weight
        # drawn from it.  GPT-2 uses std=0.02 (matches Conv1D's init).
        self.initializer_range = 0.02
        self.output_attentions = False
        for key, value in overrides.items():
            if not hasattr(self, key):
                raise AttributeError(f"unknown config field: {key}")
            setattr(self, key, value)
19
class GPT2Model(nn.Module):
    """GPT-2 transformer backbone: token + position embeddings, a stack of
    GPT2Block layers, and a final LayerNorm."""

    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)   # token embedding table
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # learned position embeddings
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, position_ids=None, past=None):
        """Run the transformer stack.

        input_ids: (batch, seq) token ids.
        position_ids: optional (batch, seq) positions; defaults to
            past_length .. past_length + seq - 1.
        past: optional per-layer list of (key, value) caches from a previous
            call, used for incremental decoding.
        Returns (hidden_states, presents): hidden_states is
        (batch, seq, n_embd); presents is the tuple of per-layer caches.
        """
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Number of already-cached positions.  The cached *value* tensor
            # has layout (batch, head, seq, head_dim), so size(-2) is the
            # sequence length.  The original read past[0][0].size(-2), but
            # the cached key is stored transposed (batch, head, head_dim, seq),
            # so that expression returns head_dim instead of the cache length.
            past_length = past[0][1].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        position_ids = position_ids.view(-1, input_shape[-1])
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if inputs_embeds.size() != position_embeds.size():
            raise ValueError("the embeddings of inputs and position are not the same size")
        hidden_states = self.drop(inputs_embeds + position_embeds)
        presents = ()
        # Use a distinct loop variable instead of shadowing the `past` list.
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, past=layer_past)
            presents = presents + (present,)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents
52
class GPT2Block(nn.Module):
    """One pre-LayerNorm transformer block: causal self-attention and a
    position-wise MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super(GPT2Block, self).__init__()
        hidden = config.n_embd
        self.ln_1 = nn.LayerNorm(hidden, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(hidden, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(config)

    def forward(self, x, past):
        # Attention sub-layer (normalized input, residual add); the attention
        # module also returns the updated (key, value) cache.
        attn_out, present = self.attn(self.ln_1(x), past=past)
        x = x + attn_out
        # Feed-forward sub-layer with its own residual connection.
        x = x + self.mlp(self.ln_2(x))
        return x, present
66
class GPT2Attention(nn.Module):
    """Multi-head causal self-attention with a (key, value) cache for
    incremental decoding."""

    def __init__(self, config):
        super(GPT2Attention, self).__init__()
        self.output_attentions = config.output_attentions  # kept for interface parity; unused below
        self.n_head = config.n_head
        self.split_size = config.n_embd
        # Whether _attn scales the logits by sqrt(head_dim).  The original
        # stored n_embd ** -0.5 here but only ever used the value as a
        # boolean, so this is now an explicit flag.
        self.scale = True
        self.c_attn = Conv1D(3 * self.split_size, self.split_size)  # fused query/key/value projection
        self.c_proj = Conv1D(self.split_size, self.split_size)      # output projection
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # Causal mask (lower-triangular ones), registered as a non-learnable
        # buffer; it matches the "attn.bias" entry of the pretrained GPT-2
        # state dict.  The original declared this as a zero-filled
        # nn.Parameter and *added* it to the projected output, which both
        # disabled causal masking and crashed at runtime:
        # (1, 1, 1024, 1024) does not broadcast against (batch, seq, n_embd).
        n_ctx = config.n_ctx
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))

    def _attn(self, q, k, v):
        """Masked scaled dot-product attention.

        q: (batch, head, nd, dk); k: (batch, head, dk, ns) (pre-transposed);
        v: (batch, head, ns, dv).  Returns (batch, head, nd, dv).
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # Apply the causal mask so query position i only attends to key
        # positions <= i (offset by the cached prefix when ns > nd).
        nd, ns = w.size(-2), w.size(-1)
        mask = self.bias[:, :, ns - nd:ns, :ns]
        w = w * mask - 1e4 * (1.0 - mask)
        w = w.softmax(dim=-1)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        """(batch, head, seq, head_dim) -> (batch, seq, head * head_dim)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        """(batch, seq, n_embd) -> (batch, head, seq, head_dim).

        Keys (k=True) come back pre-transposed as (batch, head, head_dim, seq)
        so matmul(q, k) yields attention logits directly.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, past):
        """x: (batch, seq, n_embd); past: optional (key, value) cache.
        Returns (attn_output, present) where present is the updated cache."""
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)   # (batch, head, head_dim, seq)
        value = self.split_heads(value)       # (batch, head, seq, head_dim)
        if past is not None:
            past_key, past_value = past
            # Keys are stored transposed, so their sequence axis is the last
            # dim; values keep the sequence axis at dim -2.  The original
            # concatenated values along dim=-1 (the feature axis), which
            # corrupts the cache on every incremental step.
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value)
        attn_output = self._attn(query, key, value)
        attn_output = self.merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        return attn_output, present
114
class GPT2MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4 * n_embd, apply GELU,
    project back down, then dropout."""

    def __init__(self, config):
        super(GPT2MLP, self).__init__()
        inner_dim = 4 * config.n_embd
        self.c_fc = Conv1D(inner_dim, config.n_embd)
        self.c_proj = Conv1D(config.n_embd, inner_dim)
        self.act = F.gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        hidden = self.c_fc(x)
        hidden = self.act(hidden)
        hidden = self.c_proj(hidden)
        return self.dropout(hidden)
125
class Conv1D(nn.Module):
    """GPT-2's "1D convolution": a linear layer whose weight is stored
    transposed, shape (nx, nf), so the forward pass is x @ W + b."""

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        weight = torch.empty(nx, nf)
        nn.init.normal_(weight, std=0.02)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # Flatten all leading dims, apply the affine map, restore the shape
        # with the last dim replaced by nf.
        flat = x.view(-1, x.size(-1))
        out = torch.addmm(self.bias, flat, self.weight)
        return out.view(*(x.size()[:-1] + (self.nf,)))
139
if __name__ == "__main__":
    # Demo entry point: build the model and load pretrained weights.
    # Guarded so that merely importing this module does not trigger a
    # checkpoint load from a hard-coded local path.
    config = GPT2Config()
    model = GPT2Model(config)

    # Checkpoint location can be overridden without editing the source.
    weights_path = os.environ.get("GPT2_WEIGHTS", r'c:\tmp\SI\gpt2\pytorch_model.bin')

    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source (or pass weights_only=True on
    # torch >= 1.13).
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    print(model)