armenjeddi commited on
Commit
2b4fbaa
·
verified ·
1 Parent(s): c4d3f09

Add LoopFormer model with 3 layers - max 8 iterations

Browse files
config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LoopFormerGPTForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "modeling_loopformer.GPTConfig",
7
+ "AutoModelForCausalLM": "modeling_loopformer.LoopFormerGPTForCausalLM"
8
+ },
9
+ "dtype": "bfloat16",
10
+ "model_type": "loopformer",
11
+ "transformers_version": "4.57.0"
12
+ }
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.57.0"
4
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:077fae449dd3af29d0313412ef50036e1d9a37fb8b11d447fbb41da0462d9185
3
+ size 556342528
modeling_loopformer.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ from dataclasses import dataclass
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers import PreTrainedModel, PretrainedConfig
8
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
9
+
10
+ class CausalSelfAttention(nn.Module):
11
+
12
+ def __init__(self, config):
13
+ super().__init__()
14
+ assert config.n_embd % config.n_head == 0
15
+ # key, query, value projections for all heads, but in a batch
16
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
17
+ # output projection
18
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
19
+ # regularization
20
+ self.attn_dropout = nn.Dropout(config.dropout)
21
+ self.resid_dropout = nn.Dropout(config.dropout)
22
+ self.n_head = config.n_head
23
+ self.n_embd = config.n_embd
24
+ self.dropout = config.dropout
25
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
26
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
27
+ if not self.flash:
28
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
29
+ # causal mask to ensure that attention is only applied to the left in the input sequence
30
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
31
+ .view(1, 1, config.block_size, config.block_size))
32
+
33
+ def forward(self, x):
34
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
35
+
36
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
37
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
38
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
39
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
40
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
41
+
42
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
43
+ if self.flash:
44
+ # efficient attention using Flash Attention CUDA kernels
45
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
46
+ else:
47
+ # manual implementation of attention
48
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
49
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
50
+ att = F.softmax(att, dim=-1)
51
+ att = self.attn_dropout(att)
52
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
53
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
54
+
55
+ # output projection
56
+ y = self.resid_dropout(self.c_proj(y))
57
+ return y
58
+
59
+ class MLP(nn.Module):
60
+
61
+ def __init__(self, config):
62
+ super().__init__()
63
+ self.c_fc = nn.Linear(config.n_embd, config.intermediate_dim, bias=config.bias)
64
+ self.gelu = nn.GELU()
65
+ self.c_proj = nn.Linear(config.intermediate_dim, config.n_embd, bias=config.bias)
66
+ self.dropout = nn.Dropout(config.dropout)
67
+
68
+ def forward(self, x):
69
+ x = self.c_fc(x)
70
+ x = self.gelu(x)
71
+ x = self.c_proj(x)
72
+ x = self.dropout(x)
73
+ return x
74
+
75
+ class LoopFormerBlock(nn.Module):
76
+
77
+ def __init__(self, config):
78
+ super().__init__()
79
+ self.norm_1 = nn.RMSNorm(config.n_embd, elementwise_affine=False)
80
+ self.attn = CausalSelfAttention(config)
81
+ self.norm_2 = nn.RMSNorm(config.n_embd, elementwise_affine=False)
82
+ self.mlp = MLP(config)
83
+
84
+ self.adaLN_modulation = nn.Sequential(
85
+ nn.SiLU(),
86
+ nn.Linear(config.n_embd, 4 * config.n_embd, bias=True),
87
+ )
88
+
89
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
90
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
91
+
92
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
93
+ gate_msa, gate_mlp, scale_msa, scale_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
94
+
95
+ x = x + gate_msa.unsqueeze(1) * self.attn(
96
+ self.norm_1(x) * (1 + scale_msa.unsqueeze(1))
97
+ )
98
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(
99
+ self.norm_2(x) * (1 + scale_mlp.unsqueeze(1))
100
+ )
101
+ return x
102
+
103
+ class TimestepEmbedder(nn.Module):
104
+ def __init__(self, hidden_size, frequency_embedding_size=256):
105
+ super().__init__()
106
+ self.mlp = nn.Sequential(
107
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
108
+ nn.SiLU(),
109
+ nn.Linear(hidden_size, hidden_size, bias=True),
110
+ )
111
+ self.frequency_embedding_size = frequency_embedding_size
112
+
113
+ @staticmethod
114
+ def timestep_embedding(t, dim, max_period=10000):
115
+ half = dim // 2
116
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
117
+ device=t.device
118
+ )
119
+ args = t[:, None].float() * freqs[None]
120
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
121
+ if dim % 2:
122
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
123
+ return embedding
124
+
125
+ def forward(self, t):
126
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
127
+ t_freq = t_freq.to(dtype=self.mlp[0].weight.dtype)
128
+ t_emb = self.mlp(t_freq)
129
+ return t_emb
130
+
131
+ class SharedBlock(nn.Module):
132
+ def __init__(self, depth, config):
133
+ super().__init__()
134
+ self.blocks = nn.ModuleList([
135
+ LoopFormerBlock(config) for _ in range(depth)
136
+ ])
137
+
138
+ def forward(self, x, c):
139
+ for block in self.blocks:
140
+ x = block(x, c)
141
+ return x
142
+
143
+ @dataclass
144
+ class GPTConfig(PretrainedConfig):
145
+ model_type: str = 'loopformer'
146
+ block_size: int = 1024
147
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
148
+ n_layer: int = 3
149
+ n_head: int = 32
150
+ n_embd: int = 2048
151
+ dropout: float = 0.0
152
+ bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
153
+ intermediate_dim: int = 5120
154
+
155
+ def __init__(self, **kwargs):
156
+ super().__init__(**kwargs)
157
+
158
+ class GPT(nn.Module):
159
+
160
+ def __init__(self, config):
161
+ super().__init__()
162
+ assert config.vocab_size is not None
163
+ assert config.block_size is not None
164
+ self.config = config
165
+
166
+ self.transformer = nn.ModuleDict(dict(
167
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
168
+ wpe = nn.Embedding(config.block_size, config.n_embd),
169
+ drop = nn.Dropout(config.dropout),
170
+ h = SharedBlock(config.n_layer, config),
171
+ norm_f = nn.RMSNorm(config.n_embd),
172
+ ))
173
+
174
+ self.time_embedder = TimestepEmbedder(config.n_embd)
175
+ self.dt_embedder = TimestepEmbedder(config.n_embd)
176
+
177
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
178
+ # with weight tying when using torch.compile() some warnings get generated:
179
+ # "UserWarning: functional_call was passed multiple values for tied weights.
180
+ # This behavior is deprecated and will be an error in future versions"
181
+ # not 100% sure what this is, so far seems to be harmless. TODO investigate
182
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
183
+
184
+ # init all weights
185
+ self.apply(self._init_weights)
186
+ # apply special scaled init to the residual projections, per GPT-2 paper
187
+ for pn, p in self.named_parameters():
188
+ if pn.endswith('c_proj.weight'):
189
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
190
+
191
+ # report number of parameters
192
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
193
+
194
+ def get_num_params(self, non_embedding=True):
195
+ """
196
+ Return the number of parameters in the model.
197
+ For non-embedding count (default), the position embeddings get subtracted.
198
+ The token embeddings would too, except due to the parameter sharing these
199
+ params are actually used as weights in the final layer, so we include them.
200
+ """
201
+ n_params = sum(p.numel() for p in self.parameters())
202
+ if non_embedding:
203
+ n_params -= self.transformer.wpe.weight.numel()
204
+ return n_params
205
+
206
+ def _init_weights(self, module):
207
+ if isinstance(module, nn.Linear):
208
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
209
+ if module.bias is not None:
210
+ torch.nn.init.zeros_(module.bias)
211
+ elif isinstance(module, nn.Embedding):
212
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
213
+
214
+ def forward(self, idx, targets=None, steps=[1/8]*8, **kwargs):
215
+ device = idx.device
216
+ b, t = idx.size()
217
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
218
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
219
+
220
+ # forward the GPT model itself
221
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
222
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
223
+ x = self.transformer.drop(tok_emb + pos_emb)
224
+
225
+ ti = torch.zeros(x.shape[0], dtype=x.dtype).to(x.device)
226
+ for dt in steps:
227
+ dt_base = torch.ones_like(ti) * dt
228
+ te = self.time_embedder(ti)
229
+ dte = self.dt_embedder(dt_base)
230
+ c = te + dte
231
+ x = self.transformer.h(x, c)
232
+ ti = ti + dt
233
+
234
+ x = self.transformer.norm_f(x)
235
+
236
+ logits = self.lm_head(x)
237
+
238
+ loss = None
239
+ if targets is not None:
240
+ loss = F.cross_entropy(
241
+ logits.view(-1, logits.size(-1)),
242
+ targets.view(-1),
243
+ ignore_index=-1,
244
+ )
245
+
246
+ return logits, loss
247
+
248
+ # ---- HF wrapper -------------------------------------------------------------
249
+
250
+ from transformers.generation.utils import GenerationMixin
251
+
252
+ class LoopFormerGPTForCausalLM(PreTrainedModel, GenerationMixin):
253
+ config_class = GPTConfig
254
+ main_input_name = "input_ids"
255
+ _tied_weights_keys = ["gpt.transformer.wte.weight", "gpt.lm_head.weight"]
256
+
257
+ def __init__(self, config: GPTConfig, **kwargs):
258
+ super().__init__(config)
259
+ self.gpt = GPT(config)
260
+ self.post_init()
261
+
262
+ # expose embeddings/heads for HF utilities
263
+ def get_input_embeddings(self):
264
+ return self.gpt.transformer.wte
265
+
266
+ def set_input_embeddings(self, new_emb):
267
+ self.gpt.transformer.wte = new_emb
268
+ self.gpt.lm_head.weight = new_emb.weight # keep tied
269
+
270
+ def get_output_embeddings(self):
271
+ return self.gpt.lm_head
272
+
273
+ def set_output_embeddings(self, new_out):
274
+ self.gpt.lm_head = new_out
275
+
276
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, steps=None, **kwargs):
277
+ # Let HF build the usual inputs (esp. past_key_values, position_ids, etc.)
278
+ model_inputs = super().prepare_inputs_for_generation(
279
+ input_ids=input_ids,
280
+ attention_mask=attention_mask,
281
+ **kwargs
282
+ )
283
+ # Whitelist your custom arg so `generate()` won't complain
284
+ if steps is not None:
285
+ model_inputs["steps"] = steps
286
+ return model_inputs
287
+
288
+ def forward(self, input_ids=None, attention_mask=None, labels=None, steps=None, **kwargs):
289
+ # pick steps: explicit arg > kwargs > default
290
+ if steps is None:
291
+ steps = kwargs.pop("steps", [1/8]*8)
292
+
293
+ logits, loss = self.gpt(
294
+ input_ids, targets=labels, steps=steps, attention_mask=attention_mask
295
+ )
296
+ return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "unk_token": "<|endoftext|>"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "extra_special_tokens": {},
17
+ "model_max_length": 1024,
18
+ "pad_token": "<|endoftext|>",
19
+ "tokenizer_class": "GPT2Tokenizer",
20
+ "unk_token": "<|endoftext|>"
21
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff