# File size: 5,909 Bytes
# 3df5337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
class GPT2Config:
    """Hyperparameters for the 124M-parameter (small) GPT-2 architecture."""
    def __init__(self):
        self.vocab_size = 50257           # BPE vocabulary size
        self.n_positions = 1024           # number of learned position embeddings
        self.n_ctx = 1024                 # context window (causal-mask size)
        self.n_embd = 768                 # hidden / embedding width
        self.n_layer = 12                 # number of transformer blocks
        self.n_head = 12                  # attention heads per block
        self.resid_pdrop = 0.1            # dropout after attention/MLP projections
        self.embd_pdrop = 0.1             # dropout on the summed embeddings
        self.attn_pdrop = 0.1             # dropout on attention weights
        self.layer_norm_epsilon = 1e-5
        # BUGFIX: was truncated to `0.` — GPT-2's weight-init std is 0.02.
        self.initializer_range = 0.02
        self.output_attentions = False
class GPT2Model(nn.Module):
    """GPT-2 transformer body: token + position embeddings, ``n_layer`` decoder
    blocks, and a final LayerNorm.

    ``forward`` returns ``(hidden_states, presents)`` where ``presents`` is the
    per-layer ``(key, value)`` cache for incremental decoding.
    """
    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)   # token embeddings
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # position embeddings
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
    def forward(self, input_ids, position_ids=None, past=None):
        """Run the stack.

        input_ids: LongTensor of token ids, shape (..., seq_len).
        position_ids: optional explicit positions; defaults to a range that
            continues after the cached prefix.
        past: optional per-layer (key, value) caches from a previous call.
        """
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # BUGFIX: cached keys are laid out (batch, n_head, head_dim, seq)
            # (see GPT2Attention.split_heads with k=True, concatenated along
            # dim=-1), so the cached sequence length is the LAST dimension.
            # The original read size(-2), which is the per-head dim — a
            # constant — desynchronizing the position ids during decoding.
            past_length = past[0][0].size(-1)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        position_ids = position_ids.view(-1, input_shape[-1])
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if inputs_embeds.size() != position_embeds.size():
            raise ValueError("the embeddings of inputs and position are not the same size")
        hidden_states = self.drop(inputs_embeds + position_embeds)
        presents = ()
        # Distinct loop variable: the original shadowed the `past` list with
        # each layer's cache entry.
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, past=layer_past)
            presents = presents + (present,)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents
class GPT2Block(nn.Module):
    """One pre-LayerNorm transformer layer: self-attention and MLP sub-blocks,
    each wrapped in a residual connection."""
    def __init__(self, config):
        super(GPT2Block, self).__init__()
        hidden_size = config.n_embd
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(config)
    def forward(self, x, past):
        """Return (new hidden states, this layer's (key, value) cache)."""
        attn_out, present = self.attn(self.ln_1(x), past=past)
        x = x + attn_out                    # residual around attention
        x = x + self.mlp(self.ln_2(x))      # residual around the MLP
        return x, present
class GPT2Attention(nn.Module):
    """Multi-head causal self-attention with an incremental (key, value) cache.

    Keys are cached pre-transposed as (batch, n_head, head_dim, seq), so a
    cached prefix is appended along the last dim; values are cached as
    (batch, n_head, seq, head_dim).
    """
    def __init__(self, config):
        super(GPT2Attention, self).__init__()
        self.output_attentions = config.output_attentions
        self.n_head = config.n_head
        self.split_size = config.n_embd
        # Truthy flag: any non-zero value enables sqrt(head_dim) scaling in _attn.
        self.scale = self.split_size ** -0.5
        self.c_attn = Conv1D(3 * self.split_size, self.split_size)  # fused q/k/v projection
        self.c_proj = Conv1D(self.split_size, self.split_size)      # output projection
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # BUGFIX: in GPT-2, `bias` is the causal mask — a lower-triangular
        # (1, 1, n_ctx, n_ctx) buffer of ones applied to the attention logits
        # in _attn.  The original declared it as a zero Parameter and ADDED it
        # to the (batch, seq, n_embd) projection output, which cannot even
        # broadcast, and never masked future positions.  A buffer keeps the
        # checkpoint key `attn.bias` loadable.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx),
        )
    def _attn(self, q, k, v):
        """Masked scaled dot-product attention; k arrives pre-transposed."""
        w = torch.matmul(q, k)  # logits: (batch, n_head, seq_q, seq_k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # Causal mask: query position i (offset by the cached prefix) may only
        # attend to key positions <= i.  Masked logits are pushed to -1e4 so
        # softmax sends them to ~0.
        nd, ns = w.size(-2), w.size(-1)
        mask = self.bias[:, :, ns - nd:ns, :ns]
        w = w * mask - 1e4 * (1 - mask)
        w = w.softmax(dim=-1)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)
    def merge_heads(self, x):
        # (batch, n_head, seq, head_dim) -> (batch, seq, n_embd)
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)
    def split_heads(self, x, k=False):
        # (batch, seq, n_embd) -> (batch, n_head, seq, head_dim),
        # or (batch, n_head, head_dim, seq) for keys (pre-transposed for matmul).
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)
    def forward(self, x, past):
        """Return (attention output, (key, value) cache for the next step)."""
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if past is not None:
            past_key, past_value = past
            key = torch.cat((past_key, key), dim=-1)    # keys: seq on last dim
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value)
        attn_output = self._attn(query, key, value)
        attn_output = self.merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        # BUGFIX: removed `attn_output += self.bias` — the mask is applied to
        # the logits inside _attn, and c_proj already adds the learned bias.
        attn_output = self.resid_dropout(attn_output)
        return attn_output, present
class GPT2MLP(nn.Module):
    """Position-wise feed-forward sub-block: n_embd -> 4*n_embd -> GELU ->
    n_embd, followed by residual dropout."""
    def __init__(self, config):
        super(GPT2MLP, self).__init__()
        inner_dim = 4 * config.n_embd   # GPT-2 uses a 4x expansion
        self.c_fc = Conv1D(inner_dim, config.n_embd)
        self.c_proj = Conv1D(config.n_embd, inner_dim)
        self.act = F.gelu
        self.dropout = nn.Dropout(config.resid_pdrop)
    def forward(self, x):
        """Expand, activate, project back, then apply dropout."""
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))
class Conv1D(nn.Module):
    """GPT-2-style "1x1 convolution": a linear layer whose weight is stored
    transposed, (in_features, out_features), matching the original TF
    checkpoint layout."""

    def __init__(self, nf, nx):
        """nf: output width; nx: input width."""
        super(Conv1D, self).__init__()
        self.nf = nf
        weight = torch.empty(nx, nf)
        nn.init.normal_(weight, std=0.02)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Compute x @ weight + bias over the last dim, keeping leading dims."""
        *lead, nx = x.size()
        out = torch.addmm(self.bias, x.view(-1, nx), self.weight)
        return out.view(*lead, self.nf)

def _build_pretrained_model(checkpoint_path):
    """Construct a GPT2Model and load weights from a local checkpoint.

    checkpoint_path: path to a `pytorch_model.bin` state dict matching this
    architecture (GPT-2 small layout).
    """
    config = GPT2Config()
    model = GPT2Model(config)
    # map_location='cpu' so the checkpoint loads without a GPU.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model


if __name__ == "__main__":
    # Guarding the script body stops the checkpoint load (and print) from
    # running as a side effect of importing this module.
    model = _build_pretrained_model(r'c:\tmp\SI\gpt2\pytorch_model.bin')
    print(model)