Hai929 committed
Commit d01df20 · verified · 1 Parent(s): fe0b42b

Create model.py

Files changed (1): model.py +146 -0
model.py ADDED
@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Config
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

# -------------------------------------------------
# GPT-2 Attention
# -------------------------------------------------
class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd)

        # Causal mask; config.n_positions is the maximum context length
        # (n_ctx is no longer exposed by GPT2Config in recent transformers releases).
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_positions, config.n_positions))
            .view(1, 1, config.n_positions, config.n_positions),
            persistent=False,
        )

    def forward(self, x):
        B, T, C = x.size()

        # Single projection to queries, keys, values, then split.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention with the causal mask applied.
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(y)

# -------------------------------------------------
# GPT-2 MLP
# -------------------------------------------------
class GPT2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        return self.c_proj(F.gelu(self.c_fc(x)))

# -------------------------------------------------
# GPT-2 Block
# -------------------------------------------------
class GPT2Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.mlp = GPT2MLP(config)

    def forward(self, x):
        # Pre-LayerNorm residual connections, as in GPT-2.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

# -------------------------------------------------
# GPT-2 Transformer
# -------------------------------------------------
class GPT2Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)

        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=1e-5)

    def forward(self, input_ids):
        B, T = input_ids.size()
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)

        # Token embeddings plus learned positional embeddings.
        x = self.wte(input_ids) + self.wpe(pos)
        for block in self.h:
            x = block(x)
        return self.ln_f(x)

    # Required by Hugging Face
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

# -------------------------------------------------
# GPT-2 LM Head (HF Compatible)
# -------------------------------------------------
class GPT2LMHeadModel(PreTrainedModel):
    config_class = GPT2Config
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)

        self.transformer = GPT2Transformer(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight tying
        self.lm_head.weight = self.transformer.wte.weight

        self.post_init()

    # Required by Hugging Face
    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, labels=None):
        hidden_states = self.transformer(input_ids)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard HF causal-LM convention: labels are the unshifted
            # input ids, so shift logits/labels by one position before
            # computing cross-entropy.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
        )
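
A minimal smoke test for the file above; this is a sketch, not part of the commit. The tiny config sizes, the `from model import ...` path, and the random input are illustrative assumptions:

import torch
from transformers import GPT2Config
from model import GPT2LMHeadModel  # the file added in this commit

# Tiny illustrative config; sizes are arbitrary and chosen only to run fast.
config = GPT2Config(vocab_size=1000, n_positions=128, n_embd=64, n_layer=2, n_head=4)
model = GPT2LMHeadModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids, labels=input_ids)  # labels = input_ids; the loss shift happens inside forward
print(out.loss.item(), out.logits.shape)  # scalar loss, torch.Size([2, 16, 1000])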