armenjeddi committed
Commit 0032fe7 · verified
1 Parent(s): e2017b3

Add tmlt model with 3 layers

config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "architectures": [
+     "TMLTGPTForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "modeling_tmlt.GPTConfig",
+     "AutoModelForCausalLM": "modeling_tmlt.TMLTGPTForCausalLM"
+   },
+   "dtype": "bfloat16",
+   "model_type": "tmlt",
+   "transformers_version": "4.57.0"
+ }
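
The auto_map entries point at classes defined in modeling_tmlt.py, so the checkpoint is meant to be loaded with trust_remote_code=True. A minimal loading sketch; the repo id below is a placeholder for wherever this commit lives on the Hub, not something stated in this diff:

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "armenjeddi/tmlt"  # hypothetical repo id; substitute the actual Hub path

    # trust_remote_code=True lets transformers import GPTConfig and
    # TMLTGPTForCausalLM from modeling_tmlt.py, as declared in auto_map
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    print(config.model_type)  # "tmlt"
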
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.57.0"
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1ce799d14f7d0243078dc932ed0d887d26be3d18f6dcef51acdd112307e03b07
+ size 546896752
modeling_tmlt.py ADDED
@@ -0,0 +1,284 @@
+ import math
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, PretrainedConfig
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         # regularization
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.dropout = config.dropout
+         # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+         if not self.flash:
+             print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+         # causal mask to ensure that attention is only applied to the left in the input sequence
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                     .view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+         if self.flash:
+             # efficient attention using Flash Attention CUDA kernels
+             y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
+         else:
+             # manual implementation of attention
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+             att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, config.intermediate_dim, bias=config.bias)
+         self.gelu = nn.GELU()
+         self.c_proj = nn.Linear(config.intermediate_dim, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+         return x
+
+ class LoopFormerBlock(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.norm_1 = nn.RMSNorm(config.n_embd, elementwise_affine=False)
+         self.attn = CausalSelfAttention(config)
+         self.norm_2 = nn.RMSNorm(config.n_embd, elementwise_affine=False)
+         self.mlp = MLP(config)
+
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(config.n_embd, 4 * config.n_embd, bias=True),
+         )
+
+         nn.init.zeros_(self.adaLN_modulation[1].weight)
+         nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         gate_msa, gate_mlp, scale_msa, scale_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
+
+         x = x + gate_msa.unsqueeze(1) * self.attn(
+             self.norm_1(x) * (1 + scale_msa.unsqueeze(1))
+         )
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(
+             self.norm_2(x) * (1 + scale_mlp.unsqueeze(1))
+         )
+         return x
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         t_freq = t_freq.to(dtype=self.mlp[0].weight.dtype)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+ class SharedBlock(nn.Module):
+     def __init__(self, depth, config):
+         super().__init__()
+         self.blocks = nn.ModuleList([
+             LoopFormerBlock(config) for _ in range(depth)
+         ])
+
+     def forward(self, x, c):
+         for block in self.blocks:
+             x = block(x, c)
+         return x
+
+ @dataclass
+ class GPTConfig(PretrainedConfig):
+     model_type: str = 'tmlt'
+     block_size: int = 1024
+     vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
+     n_layer: int = 3
+     n_head: int = 32
+     n_embd: int = 2048
+     dropout: float = 0.0
+     bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+     intermediate_dim: int = 5120
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+ class GPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.vocab_size is not None
+         assert config.block_size is not None
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.dropout),
+             h = SharedBlock(config.n_layer, config),
+             norm_f = nn.RMSNorm(config.n_embd),
+         ))
+
+         self.time_embedder = TimestepEmbedder(config.n_embd)
+
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         # with weight tying when using torch.compile() some warnings get generated:
+         # "UserWarning: functional_call was passed multiple values for tied weights.
+         # This behavior is deprecated and will be an error in future versions"
+         # not 100% sure what this is, so far seems to be harmless. TODO investigate
+         self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
+
+         # init all weights
+         self.apply(self._init_weights)
+         # apply special scaled init to the residual projections, per GPT-2 paper
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+         # report number of parameters
+         print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+     def get_num_params(self, non_embedding=True):
+         """
+         Return the number of parameters in the model.
+         For non-embedding count (default), the position embeddings get subtracted.
+         The token embeddings would too, except due to the parameter sharing these
+         params are actually used as weights in the final layer, so we include them.
+         """
+         n_params = sum(p.numel() for p in self.parameters())
+         if non_embedding:
+             n_params -= self.transformer.wpe.weight.numel()
+         return n_params
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None, steps=8, **kwargs):
+         device = idx.device
+         b, t = idx.size()
+         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+         x = self.transformer.drop(tok_emb + pos_emb)
+
+         for step_i in range(steps): # re-apply the same shared block stack, conditioning each pass on its step index
+             ti = torch.full((b,), step_i, dtype=torch.long, device=x.device)
+             te = self.time_embedder(ti)
+             x = self.transformer.h(x, te)
+
+         x = self.transformer.norm_f(x)
+
+         logits = self.lm_head(x)
+
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-1,
+             )
+
+         return logits, loss
+
+
+ # ---- HF wrapper -------------------------------------------------------------
+
+ from transformers.generation.utils import GenerationMixin
+
+ class TMLTGPTForCausalLM(PreTrainedModel, GenerationMixin):
+     config_class = GPTConfig
+     main_input_name = "input_ids"
+     _tied_weights_keys = ["gpt.transformer.wte.weight", "gpt.lm_head.weight"]
+
+     def __init__(self, config: GPTConfig, **kwargs):
+         super().__init__(config)
+         self.gpt = GPT(config)
+         self.post_init()
+
+     # expose embeddings/heads for HF utilities
+     def get_input_embeddings(self):
+         return self.gpt.transformer.wte
+
+     def set_input_embeddings(self, new_emb):
+         self.gpt.transformer.wte = new_emb
+         self.gpt.lm_head.weight = new_emb.weight # keep tied
+
+     def get_output_embeddings(self):
+         return self.gpt.lm_head
+
+     def set_output_embeddings(self, new_out):
+         self.gpt.lm_head = new_out
+
+     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             # no labels during generation
+         }
+
+     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+         # attention_mask is accepted for API compatibility; the inner GPT only applies its causal mask
+         logits, loss = self.gpt(
+             input_ids, targets=labels, attention_mask=attention_mask
+         )
+         return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)
+
+
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "pad_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "50256": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "extra_special_tokens": {},
+   "model_max_length": 1024,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
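
The tokenizer files describe a stock GPT-2 BPE tokenizer (vocab.json, merges.txt, tokenizer.json) with <|endoftext|> reused as bos, eos, pad and unk, and model_max_length matching the 1024-token block size. A hedged end-to-end generation sketch; as above, the repo id is a hypothetical placeholder:

    from transformers import AutoTokenizer, AutoModelForCausalLM

    repo_id = "armenjeddi/tmlt"  # hypothetical repo id; substitute the actual Hub path

    tok = AutoTokenizer.from_pretrained(repo_id)  # GPT2Tokenizer per tokenizer_config.json
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    inputs = tok("Once upon a time", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=32, do_sample=True, top_k=50,
                         pad_token_id=tok.eos_token_id)
    print(tok.decode(out[0], skip_special_tokens=True))

Since the wrapper exposes no key-value cache, generate() re-runs the full forward (all 8 loop steps) for every new token, so decoding is functional but not fast.
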
vocab.json ADDED
The diff for this file is too large to render. See raw diff