import torch import torch.nn as nn from torch.nn import functional as F import math class GPT2Config: def __init__(self): self.vocab_size = 50257 self.n_positions = 1024 self.n_ctx = 1024 self.n_embd = 768 self.n_layer = 12 self.n_head = 12 self.resid_pdrop = 0.1 self.embd_pdrop = 0.1 self.attn_pdrop = 0.1 self.layer_norm_epsilon = 1e-5 self.initializer_range = 0. self.output_attentions = False class GPT2Model(nn.Module): def __init__(self, config): super(GPT2Model, self).__init__() self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) def forward(self, input_ids, position_ids=None, past=None): if past is None: past_length = 0 past = [None] * len(self.h) else: past_length = past[0][0].size(-2) if position_ids is None: position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) position_ids = position_ids.view(-1, input_shape[-1]) inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) if inputs_embeds.size() != position_embeds.size(): raise ValueError("the embeddings of inputs and position are not the same size") hidden_states = inputs_embeds + position_embeds hidden_states = self.drop(hidden_states) presents = () for block, past in zip(self.h, past): hidden_states, present = block(hidden_states, past=past) presents = presents + (present,) hidden_states = self.ln_f(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) return hidden_states.view(*output_shape), presents class GPT2Block(nn.Module): def __init__(self, config): super(GPT2Block, self).__init__() nx = config.n_embd self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) self.attn = GPT2Attention(config) self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(config) def forward(self, x, past): a, present = self.attn(self.ln_1(x), past=past) x = x + a m = self.mlp(self.ln_2(x)) x = x + m return x, present class GPT2Attention(nn.Module): def __init__(self, config): super(GPT2Attention, self).__init__() self.output_attentions = config.output_attentions self.n_head = config.n_head self.split_size = config.n_embd self.scale = self.split_size ** -0.5 self.c_attn = Conv1D(3 * self.split_size, self.split_size) self.c_proj = Conv1D(self.split_size, self.split_size) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) self.bias = nn.Parameter(torch.zeros(1, 1, 1024, 1024)) def _attn(self, q, k, v): w = torch.matmul(q, k) if self.scale: w = w / math.sqrt(v.size(-1)) w = w.softmax(dim=-1) w = self.attn_dropout(w) return torch.matmul(w, v) def merge_heads(self, x): x = x.permute(0, 2, 1, 3).contiguous() new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) return x.view(*new_x_shape) def split_heads(self, x, k=False): new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) x = x.view(*new_x_shape) if k: return x.permute(0, 2, 3, 1) else: return x.permute(0, 2, 1, 3) def forward(self, x, past): x = self.c_attn(x) query, key, value = x.split(self.split_size, dim=2) query = self.split_heads(query) key = self.split_heads(key, k=True) value = self.split_heads(value) if past is not None: past_key, past_value = past key = torch.cat((past_key, key), dim=-1) value = torch.cat((past_value, value), dim=-1) present = (key, value) attn_output = self._attn(query, key, value) attn_output = self.merge_heads(attn_output) attn_output = self.c_proj(attn_output) # Apply bias attn_output += self.bias attn_output = self.resid_dropout(attn_output) return attn_output, present class GPT2MLP(nn.Module): def __init__(self, config): super(GPT2MLP, self).__init__() self.c_fc = Conv1D(4 * config.n_embd, config.n_embd) self.c_proj = Conv1D(config.n_embd, 4 * config.n_embd) self.act = F.gelu self.dropout = nn.Dropout(config.resid_pdrop) def forward(self, x): h = self.act(self.c_fc(x)) h2 = self.c_proj(h) return self.dropout(h2) class Conv1D(nn.Module): def __init__(self, nf, nx): super(Conv1D, self).__init__() self.nf = nf w = torch.empty(nx, nf) nn.init.normal_(w, std=0.02) self.weight = nn.Parameter(w) self.bias = nn.Parameter(torch.zeros(nf)) def forward(self, x): size_out = x.size()[:-1] + (self.nf,) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) x = x.view(*size_out) return x config = GPT2Config() model = GPT2Model(config) state_dict = torch.load(r'c:\tmp\SI\gpt2\pytorch_model.bin', map_location=torch.device('cpu')) model.load_state_dict(state_dict) print(model)