klemenk commited on
Commit
9241475
·
verified ·
1 Parent(s): e37991c

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +206 -0
model.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from transformers import PreTrainedModel
6
+ from .config import GPTConfig
7
+
8
+
9
+ ################################
10
+ ### Layers ###
11
+ ################################
12
+
13
+ class Rotary(torch.nn.Module):
14
+
15
+ def __init__(self, dim, base=10000):
16
+ super().__init__()
17
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
18
+ self.register_buffer("inv_freq", inv_freq)
19
+ self.seq_len_cached = None
20
+ self.cos_cached = None
21
+ self.sin_cached = None
22
+
23
+ def forward(self, x):
24
+ seq_len = x.shape[1]
25
+ if seq_len != self.seq_len_cached:
26
+ self.seq_len_cached = seq_len
27
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
28
+ freqs = torch.outer(t, self.inv_freq).to(x.device)
29
+ self.cos_cached = freqs.cos()
30
+ self.sin_cached = freqs.sin()
31
+ return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
32
+
33
+ def apply_rotary_emb(x, cos, sin):
34
+ assert x.ndim == 4 # multihead attention
35
+ d = x.shape[3]//2
36
+ x1 = x[..., :d]
37
+ x2 = x[..., d:]
38
+ y1 = x1 * cos + x2 * sin
39
+ y2 = x1 * (-sin) + x2 * cos
40
+ return torch.cat([y1, y2], 3)
41
+
42
+ def rmsnorm(x0, eps=1e-6):
43
+ x = x0.float()
44
+ x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
45
+ return x.type_as(x0)
46
+
47
+
48
+ class RMSNorm(nn.Module):
49
+ """ Root Mean Square Normalization """
50
+ def __init__(self, dim: int, weight: bool = False, bias: bool = False, eps: float = 1e-6):
51
+ super().__init__()
52
+ self.eps = eps
53
+
54
+ if weight:
55
+ self.weight = nn.Parameter(torch.ones(dim))
56
+ else:
57
+ self.register_parameter("weight", None)
58
+
59
+ if bias:
60
+ self.bias = nn.Parameter(torch.zeros(dim))
61
+ else:
62
+ self.register_parameter("bias", None)
63
+
64
+ def _norm(self, x):
65
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
66
+
67
+ def forward(self, x):
68
+ output = self._norm(x.float()).type_as(x)
69
+ if self.weight is not None:
70
+ output = output * self.weight
71
+ if self.bias is not None:
72
+ output = output + self.bias
73
+ return output
74
+
75
+
76
+ class CausalSelfAttention(nn.Module):
77
+
78
+ def __init__(self, config):
79
+ super().__init__()
80
+ self.n_head = config.n_head
81
+ self.n_embd = config.n_embd
82
+ self.head_dim = self.n_embd // self.n_head
83
+ assert self.n_embd % self.n_head == 0
84
+ # key, query, value projections for all heads, but in a batch
85
+ self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
86
+ # output projection
87
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
88
+ self.rotary = Rotary(self.head_dim)
89
+
90
+ def forward(self, x):
91
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
92
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
93
+ qkv = self.c_attn(x)
94
+ q, k, v = qkv.split(self.n_embd, dim=2)
95
+ k = k.view(B, T, self.n_head, self.head_dim)
96
+ q = q.view(B, T, self.n_head, self.head_dim)
97
+ v = v.view(B, T, self.n_head, self.head_dim)
98
+ cos, sin = self.rotary(q)
99
+ q = apply_rotary_emb(q, cos, sin)
100
+ k = apply_rotary_emb(k, cos, sin)
101
+ y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
102
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
103
+ # output projection
104
+ y = self.c_proj(y)
105
+ return y
106
+
107
+ class RMSNorm(nn.Module):
108
+ def __init__(self, dim, eps=1e-5):
109
+ super().__init__()
110
+ self.eps = eps
111
+ self.weight = nn.Parameter(torch.ones(dim))
112
+
113
+ def forward(self, x):
114
+ norm = torch.norm(x, dim=-1, keepdim=True)
115
+ return self.weight * x / (norm + self.eps)
116
+
117
+ class Block(nn.Module):
118
+
119
+ def __init__(self, config):
120
+ super().__init__()
121
+ self.attn = CausalSelfAttention(config)
122
+ self.mlp = MLP(config)
123
+ self.attn_scale = (1 / (2 * config.n_layer)**0.5)
124
+
125
+ def forward(self, x):
126
+ x = x + self.attn_scale * self.attn(rmsnorm(x))
127
+ x = x + self.mlp(rmsnorm(x))
128
+ return x
129
+
130
+ class MLP(nn.Module):
131
+
132
+ def __init__(self, config):
133
+ super().__init__()
134
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
135
+ self.gelu = nn.GELU()
136
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
137
+ self.dropout = nn.Dropout(config.dropout)
138
+
139
+ def forward(self, x):
140
+ x = self.c_fc(x)
141
+ x = self.gelu(x)
142
+ x = self.c_proj(x)
143
+ x = self.dropout(x)
144
+ return x
145
+
146
+
147
+ ################################
148
+ ### Model ###
149
+ ################################
150
+
151
+ class GPT(PreTrainedModel):
152
+ config_class = GPTConfig
153
+
154
+ def __init__(self, config):
155
+ super().__init__(config)
156
+ self.transformer = nn.ModuleDict(dict(
157
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
158
+ drop=nn.Dropout(config.dropout),
159
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
160
+ ))
161
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
162
+
163
+ self.apply(self._init_weights)
164
+
165
+ # GPT-2 style scaled init
166
+ for pn, p in self.named_parameters():
167
+ if pn.endswith('c_proj.weight'):
168
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
169
+
170
+ def _init_weights(self, module):
171
+ if isinstance(module, nn.Linear):
172
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
173
+ if module.bias is not None:
174
+ torch.nn.init.zeros_(module.bias)
175
+ elif isinstance(module, nn.Embedding):
176
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
177
+
178
+ def forward(self, input_ids, labels=None):
179
+ tok_emb = self.transformer.wte(input_ids)
180
+ x = self.transformer.drop(tok_emb)
181
+
182
+ for block in self.transformer.h:
183
+ x = block(x)
184
+ x = rmsnorm(x)
185
+
186
+ logits = self.lm_head(x)
187
+
188
+ loss = None
189
+ if labels is not None:
190
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
191
+
192
+ return {'loss': loss, 'logits': logits} if loss is not None else {'logits': logits}
193
+
194
+ @torch.no_grad()
195
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
196
+ for _ in range(max_new_tokens):
197
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
198
+ logits = self(idx_cond)['logits']
199
+ logits = logits[:, -1, :] / temperature
200
+ if top_k is not None:
201
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
202
+ logits[logits < v[:, [-1]]] = -float('Inf')
203
+ probs = F.softmax(logits, dim=-1)
204
+ idx_next = torch.multinomial(probs, num_samples=1)
205
+ idx = torch.cat((idx, idx_next), dim=1)
206
+ return idx