williamconvertino commited on
Commit
6a2ee89
·
verified ·
1 Parent(s): 1f309d1

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "fst",
3
+ "_class_name": "FSTConfig",
4
+ "architectures": [
5
+ "FSTForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "config.FSTConfig",
9
+ "AutoModel": "model.FSTModel",
10
+ "AutoModelForCausalLM": "model.FSTForCausalLM"
11
+ },
12
+
13
+ "vocab_size": 50257,
14
+ "hidden_size": 2048,
15
+ "num_hidden_layers": 24,
16
+ "num_attention_heads": 32,
17
+ "intermediate_size": 8192,
18
+ "max_position_embeddings": 2048,
19
+
20
+ "use_causal_attention": true,
21
+ "use_cache": false,
22
+
23
+ "initializer_range": 0.02,
24
+
25
+ "bos_token_id": 50256,
26
+ "eos_token_id": 50256,
27
+ "pad_token_id": 50256,
28
+
29
+ "transformers_version": "4.57.1"
30
+
31
+ }
config.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class FSTConfig(PretrainedConfig):
4
+ model_type = "fst"
5
+
6
+ def __init__(
7
+ self,
8
+
9
+ # Core
10
+ vocab_size: int = 50257,
11
+ hidden_size: int = 2048,
12
+ num_hidden_layers: int = 24,
13
+ num_attention_heads: int = 32,
14
+ intermediate_size: int = 8192,
15
+ max_position_embeddings: int = 2048,
16
+
17
+ # Attention
18
+ use_causal_attention: bool = True,
19
+ use_cache: bool = True, # Disable during training
20
+
21
+ # Initialization and Normalization
22
+ initializer_range: float = 0.02,
23
+
24
+ # Tokenizer
25
+ bos_token_id: int | None = None,
26
+ eos_token_id: int | None = None,
27
+ pad_token_id: int | None = None,
28
+
29
+ **kwargs,
30
+ ):
31
+ super().__init__(
32
+ bos_token_id=bos_token_id,
33
+ eos_token_id=eos_token_id,
34
+ pad_token_id=pad_token_id,
35
+ **kwargs,
36
+ )
37
+
38
+ # Core
39
+ self.vocab_size = vocab_size
40
+ self.hidden_size = hidden_size
41
+ self.num_hidden_layers = num_hidden_layers
42
+ self.num_attention_heads = num_attention_heads
43
+ self.intermediate_size = intermediate_size
44
+ self.max_position_embeddings = max_position_embeddings
45
+
46
+ # Attention
47
+ self.use_causal_attention = use_causal_attention
48
+ self.use_cache = use_cache
49
+
50
+ # Initialization and Normalization
51
+ self.initializer_range = initializer_range
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "pad_token_id": 50256,
6
+ "do_sample": true,
7
+ "temperature": 0.3,
8
+ "top_p": 0.95,
9
+ "transformers_version": "4.57.1",
10
+ "use_cache": true
11
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a36093520284da151a5cea8f1171b6d0f6d017e00fa7eed7985d0fa0b1e14eb5
3
+ size 4977398040
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:601c31659cc8991731d7eec7a509d53e2ec434b0f4a531c3401d7d2ce7dc2755
3
+ size 268569024
model.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers import PreTrainedModel, GenerationMixin
9
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MaskedLMOutput
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+
12
+ from rotary_embedding_torch import RotaryEmbedding
13
+ from .config import FSTConfig
14
+
15
+ # === Util ===
16
+
17
+ class Residual(nn.Module):
18
+ def __init__(self):
19
+ super().__init__()
20
+
21
+ def forward(self, x: Tensor, delta: Tensor):
22
+ return x + delta
23
+
24
+ # === MLP ===
25
+
26
+ class MLP(nn.Module):
27
+ def __init__(
28
+ self,
29
+ hidden_size: int,
30
+ intermediate_size: int
31
+ ):
32
+ super().__init__()
33
+
34
+ self.fc_up = nn.Linear(hidden_size, intermediate_size)
35
+ self.activation = nn.GELU()
36
+ self.fc_down = nn.Linear(intermediate_size, hidden_size)
37
+
38
+ def forward(self, x: Tensor):
39
+ return self.fc_down(self.activation(self.fc_up(x)))
40
+
41
+ # === Attention ===
42
+
43
+ class MHAttention(nn.Module):
44
+
45
+ def __init__(
46
+ self,
47
+ hidden_size: int,
48
+ num_attention_heads: int,
49
+ use_causal_attention: bool = True,
50
+ layer_idx: int | None = None
51
+ ):
52
+ super().__init__()
53
+
54
+ self.hidden_size = hidden_size
55
+ self.num_attention_heads = num_attention_heads
56
+ self.head_dim = hidden_size // num_attention_heads
57
+
58
+ assert self.head_dim * self.num_attention_heads == self.hidden_size
59
+
60
+ self.use_causal_attention = use_causal_attention
61
+ self.layer_idx = layer_idx
62
+
63
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
64
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
65
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
66
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
67
+
68
+ self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
69
+ self.scale = self.head_dim ** -0.5
70
+
71
+ def forward(
72
+ self,
73
+ q: Tensor,
74
+ k: Tensor | None = None,
75
+ v: Tensor | None = None,
76
+ attention_mask: Tensor | None = None,
77
+ past_key_values: Cache | None = None
78
+ ):
79
+ B, T, _ = q.size()
80
+
81
+ if k is None:
82
+ k = q
83
+ if v is None:
84
+ v = q
85
+
86
+ q = self.q_proj(q)
87
+ k = self.k_proj(k)
88
+ v = self.v_proj(v)
89
+
90
+ q = q.view(B, T, self.num_attention_heads, self.head_dim).transpose(1, 2)
91
+ k = k.view(B, T, self.num_attention_heads, self.head_dim).transpose(1, 2)
92
+ v = v.view(B, T, self.num_attention_heads, self.head_dim).transpose(1, 2)
93
+
94
+ if past_key_values is None:
95
+
96
+ q = self.rotary_emb.rotate_queries_or_keys(q)
97
+ k = self.rotary_emb.rotate_queries_or_keys(k)
98
+
99
+ else:
100
+
101
+ cache_position = past_key_values.get_seq_length(self.layer_idx)
102
+
103
+ q = self.rotary_emb.rotate_queries_or_keys(q, offset=cache_position)
104
+ k = self.rotary_emb.rotate_queries_or_keys(k, offset=cache_position)
105
+
106
+ k, v = past_key_values.update(k, v, self.layer_idx)
107
+
108
+ is_causal = self.use_causal_attention and attention_mask is None
109
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, scale=self.scale, is_causal=is_causal)
110
+
111
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)
112
+ out = self.o_proj(attn_output)
113
+
114
+ return out
115
+
116
+ # === Blocks ===
117
+
118
+ class FeatureBlock(nn.Module):
119
+
120
+ def __init__(
121
+ self,
122
+ config: FSTConfig,
123
+ layer_idx: int = None
124
+ ):
125
+ super().__init__()
126
+
127
+ self.attn = MHAttention(
128
+ hidden_size=config.hidden_size,
129
+ num_attention_heads=config.num_attention_heads,
130
+ use_causal_attention=config.use_causal_attention,
131
+ layer_idx=layer_idx,
132
+ )
133
+
134
+ self.mlp = MLP(
135
+ config.hidden_size,
136
+ config.intermediate_size
137
+ )
138
+
139
+ self.norm_attn = nn.LayerNorm(config.hidden_size)
140
+ self.norm_mlp = nn.LayerNorm(config.hidden_size)
141
+
142
+ self.resid_attn = Residual()
143
+ self.resid_mlp = Residual()
144
+
145
+ def forward(
146
+ self,
147
+ x: Tensor,
148
+ attention_mask: Tensor | None = None,
149
+ past_key_values: Cache | None = None
150
+ ):
151
+
152
+ attn_out = self.attn(self.norm_attn(x), attention_mask=attention_mask, past_key_values=past_key_values)
153
+ x = self.resid_attn(x, attn_out)
154
+
155
+ mlp_out = self.mlp(self.norm_mlp(x))
156
+ x = self.resid_mlp(x, mlp_out)
157
+
158
+ return x
159
+
160
+ class PredictiveBlock(nn.Module):
161
+
162
+ def __init__(
163
+ self,
164
+ config: FSTConfig,
165
+ layer_idx: int = None
166
+ ):
167
+ super().__init__()
168
+
169
+ self.attn = MHAttention(
170
+ hidden_size=config.hidden_size,
171
+ num_attention_heads=config.num_attention_heads,
172
+ use_causal_attention=config.use_causal_attention,
173
+ layer_idx=layer_idx,
174
+ )
175
+
176
+ self.mlp = MLP(
177
+ config.hidden_size,
178
+ config.intermediate_size
179
+ )
180
+
181
+ self.norm_attn_qk = nn.LayerNorm(config.hidden_size)
182
+ self.norm_attn_v = nn.LayerNorm(config.hidden_size)
183
+ self.norm_mlp = nn.LayerNorm(config.hidden_size)
184
+
185
+ self.resid_attn = Residual()
186
+ self.resid_mlp = Residual()
187
+
188
+ def forward(
189
+ self,
190
+ phi: Tensor,
191
+ f: Tensor,
192
+ e: Tensor,
193
+ attention_mask: Tensor | None = None,
194
+ past_key_values: Cache | None = None
195
+ ):
196
+
197
+ qk = self.norm_attn_qk(phi)
198
+ v = self.norm_attn_v(e)
199
+
200
+ attn_out = self.attn(qk, qk, v, attention_mask=attention_mask, past_key_values=past_key_values)
201
+ f = self.resid_attn(f, attn_out)
202
+
203
+ mlp_out = self.mlp(self.norm_mlp(f))
204
+ f = self.resid_mlp(f, mlp_out)
205
+
206
+ return f
207
+
208
+ # === Base Model ===
209
+
210
+ class FSTPreTrainedModel(PreTrainedModel):
211
+
212
+ config_class = FSTConfig
213
+ base_model_prefix = "model"
214
+ _no_split_modules = ["FSTBlock"]
215
+ _skip_keys_device_placement = ["past_key_values"]
216
+ _supports_flash_attn_2 = True
217
+ _supports_cache_class = True
218
+
219
+ # Initialization taken from Deepseek and Falcon
220
+ def _init_weights(self, module):
221
+ std = self.config.initializer_range
222
+ if isinstance(module, nn.Linear):
223
+ module.weight.data.normal_(mean=0.0, std=std)
224
+ if module.bias is not None:
225
+ module.bias.data.zero_()
226
+ elif isinstance(module, nn.Embedding):
227
+ module.weight.data.normal_(mean=0.0, std=std)
228
+ if module.padding_idx is not None:
229
+ module.weight.data[module.padding_idx].zero_()
230
+
231
+ class FSTModel(FSTPreTrainedModel):
232
+
233
+ def __init__(
234
+ self,
235
+ config: FSTConfig
236
+ ):
237
+ super().__init__(config)
238
+
239
+ self.config = config
240
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
241
+
242
+ self.feature_blocks = nn.ModuleList([FeatureBlock(config, layer_idx) for layer_idx in range(0, config.num_hidden_layers, 2)])
243
+ self.predictive_blocks = nn.ModuleList([PredictiveBlock(config, layer_idx) for layer_idx in range(1, config.num_hidden_layers, 2)])
244
+ self.norm_out = nn.LayerNorm(config.hidden_size)
245
+
246
+ self.post_init()
247
+
248
+ def _prepare_attention_mask(
249
+ self,
250
+ x: Tensor,
251
+ attention_mask: Tensor | None = None,
252
+ past_key_values: Cache | None = None,
253
+ use_causal_attention: bool = True
254
+ ):
255
+
256
+ device = x.device
257
+ B = x.shape[0]
258
+ T = x.shape[1]
259
+
260
+ T_past = past_key_values.get_seq_length() if past_key_values is not None else 0
261
+ T_total = T + T_past
262
+
263
+ if use_causal_attention:
264
+ causal_mask = ~torch.triu(
265
+ torch.ones((T, T_total), dtype=torch.bool, device=device),
266
+ diagonal=(1 + T_past)
267
+ ).unsqueeze(0).unsqueeze(0)
268
+
269
+ if attention_mask is not None:
270
+ attn_len = attention_mask.shape[-1]
271
+
272
+ if attn_len < T_total:
273
+ pad = torch.ones(B, T_past, device=device, dtype=attention_mask.dtype) # Fixed: ones instead of zeros
274
+ attention_mask = torch.cat([pad, attention_mask], dim=-1)
275
+ elif attn_len > T_total:
276
+ attention_mask = attention_mask[:, -T_total:]
277
+
278
+ expanded_mask = (attention_mask == 1).view(B, 1, 1, T_total)
279
+
280
+ if use_causal_attention and attention_mask is not None:
281
+ return causal_mask & expanded_mask
282
+ elif use_causal_attention:
283
+ return causal_mask
284
+ elif attention_mask is not None: # Added: handle non-causal with custom mask
285
+ return expanded_mask
286
+ else:
287
+ return torch.ones((1, 1, T, T_total), dtype=torch.bool, device=device)
288
+
289
+ def forward(
290
+ self,
291
+ input_ids: Tensor | None = None,
292
+ attention_mask: Tensor | None = None,
293
+ inputs_embeds: Tensor | None = None,
294
+ past_key_values = None,
295
+ use_cache: bool | None = None,
296
+ output_hidden_states: bool | None = None,
297
+ return_dict: bool | None = None,
298
+ **kwargs,
299
+ ):
300
+
301
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
302
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
303
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
304
+
305
+ assert not (input_ids is not None and inputs_embeds is not None), "You cannot specify both input_ids and inputs_embeds"
306
+ assert not (input_ids is None and inputs_embeds is None), "You must specify either input_ids or inputs_embeds"
307
+
308
+ e = self.embedding(input_ids) if input_ids is not None else inputs_embeds
309
+
310
+ B, T, _ = e.shape
311
+ device = e.device
312
+ dtype = e.dtype
313
+
314
+ if not use_cache:
315
+ past_key_values=None
316
+ elif past_key_values is None:
317
+ past_key_values = DynamicCache()
318
+
319
+ # Note that we must use an attention mask when caching- otherwise, SDPA uses is_casual and breaks
320
+ if attention_mask is not None or past_key_values is not None:
321
+ attention_mask = self._prepare_attention_mask(e, attention_mask=attention_mask, use_causal_attention=self.config.use_causal_attention, past_key_values=past_key_values)
322
+
323
+ hidden_states = [] if output_hidden_states else None
324
+
325
+ phi = e
326
+ f = torch.zeros(B, T, self.config.hidden_size, dtype=dtype, device=device) # Initialize f as zero for purity, but f=e also works fine
327
+
328
+ for feature_block, predictive_block in zip(self.feature_blocks, self.predictive_blocks):
329
+
330
+ phi = feature_block(phi, attention_mask=attention_mask, past_key_values=past_key_values)
331
+ f = predictive_block(phi, f, e, attention_mask=attention_mask, past_key_values=past_key_values)
332
+
333
+ if output_hidden_states:
334
+ hidden_states.append(phi)
335
+ hidden_states.append(f)
336
+
337
+ if hidden_states is not None:
338
+ hidden_states = tuple(hidden_states)
339
+
340
+ f = self.norm_out(f)
341
+
342
+ if return_dict:
343
+ return BaseModelOutputWithPast(
344
+ last_hidden_state=f,
345
+ past_key_values=past_key_values,
346
+ hidden_states=hidden_states
347
+ )
348
+
349
+ return f, past_key_values, hidden_states
350
+
351
+ # === Applied Models ===
352
+
353
+ class FSTForCausalLM(GenerationMixin, FSTPreTrainedModel):
354
+
355
+ accepts_loss_kwargs = False
356
+
357
+ def __init__(
358
+ self,
359
+ config: FSTConfig
360
+ ):
361
+ super().__init__(config)
362
+
363
+ self.model = FSTModel(config)
364
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
365
+
366
+ if config.tie_word_embeddings:
367
+ self.tie_weights()
368
+ self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"} # Avoids safetensor naming issues
369
+
370
+ self.post_init()
371
+
372
+ def get_input_embeddings(self):
373
+ return self.model.embedding
374
+
375
+ def set_input_embeddings(self, new_embeddings):
376
+ self.model.embedding = new_embeddings
377
+
378
+ def get_output_embeddings(self):
379
+ return self.lm_head
380
+
381
+ def set_output_embeddings(self, new_embeddings):
382
+ self.lm_head = new_embeddings
383
+
384
+ def tie_weights(self):
385
+ self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings())
386
+
387
+ def forward(
388
+ self,
389
+ input_ids: Tensor | None = None,
390
+ attention_mask: Tensor | None = None,
391
+ past_key_values = None,
392
+ inputs_embeds: Tensor | None = None,
393
+ labels: Tensor | None = None,
394
+ use_cache: bool | None = None,
395
+ output_hidden_states: bool | None = None,
396
+ return_dict: bool | None = None,
397
+ **kwargs,
398
+ ):
399
+
400
+ if labels is not None:
401
+ return_dict = True
402
+ else:
403
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
404
+
405
+ model_output = self.model(
406
+ input_ids=input_ids,
407
+ attention_mask=attention_mask,
408
+ inputs_embeds=inputs_embeds,
409
+ past_key_values=past_key_values,
410
+ use_cache=use_cache,
411
+ output_hidden_states=output_hidden_states
412
+ )
413
+
414
+ logits = self.lm_head(model_output[0])
415
+
416
+ loss = None
417
+ if labels is not None:
418
+ shift_logits = logits[:, :-1, :].contiguous()
419
+ shift_labels = labels[:, 1:].contiguous()
420
+ loss = F.cross_entropy(
421
+ shift_logits.view(-1, shift_logits.size(-1)),
422
+ shift_labels.view(-1),
423
+ ignore_index=self.config.pad_token_id if self.config.pad_token_id is not None else -100
424
+ )
425
+
426
+ if not return_dict:
427
+ output = (logits,) + model_output[1:]
428
+ return ((loss,) + output) if loss is not None else output
429
+
430
+ return CausalLMOutputWithPast(
431
+ loss=loss,
432
+ logits=logits,
433
+ past_key_values=model_output.past_key_values,
434
+ hidden_states=model_output.hidden_states
435
+ )
436
+
437
+ def _prepare_inputs_for_generation(
438
+ self,
439
+ input_ids: Tensor,
440
+ past_key_values: Cache | None = None,
441
+ attention_mask: Tensor | None = None,
442
+ **kwargs
443
+ ):
444
+ if past_key_values is not None:
445
+ input_ids = input_ids[:, -1:]
446
+
447
+ model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
448
+
449
+ if attention_mask is not None:
450
+ model_inputs["attention_mask"] = attention_mask
451
+
452
+ for key, value in kwargs.items():
453
+ model_inputs[key] = value
454
+
455
+ return model_inputs
456
+
457
+ def _reorder_cache(self, past_key_values: Cache, beam_idx: Tensor):
458
+ return past_key_values.reorder_cache(beam_idx)
459
+
460
+ class FSTForMaskedLM(FSTPreTrainedModel):
461
+
462
+ accepts_loss_kwargs = False
463
+
464
+ def __init__(
465
+ self,
466
+ config: FSTConfig
467
+ ):
468
+ super().__init__(config)
469
+
470
+ assert not config.use_causal_attention, "FSTForMaskedLM requires use_causal_attention=False"
471
+ assert not config.use_cache, "FSTForMaskedLM requires use_cache=False (caching not supported for bidirectional models)"
472
+
473
+ self.model = FSTModel(config)
474
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
475
+
476
+ if config.tie_word_embeddings:
477
+ self.tie_weights()
478
+ self._dynamic_tied_weights_keys = {"lm_head.weight": "model.embedding.weight"} # Avoids safetensor naming issues
479
+
480
+ self.post_init()
481
+
482
+ def get_input_embeddings(self):
483
+ return self.model.embedding
484
+
485
+ def set_input_embeddings(self, new_embeddings):
486
+ self.model.embedding = new_embeddings
487
+
488
+ def get_output_embeddings(self):
489
+ return self.lm_head
490
+
491
+ def set_output_embeddings(self, new_embeddings):
492
+ self.lm_head = new_embeddings
493
+
494
+ def tie_weights(self):
495
+ self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings())
496
+
497
+ def forward(
498
+ self,
499
+ input_ids: Tensor | None = None,
500
+ attention_mask: Tensor | None = None,
501
+ inputs_embeds: Tensor | None = None,
502
+ labels: Tensor | None = None,
503
+ output_hidden_states: bool | None = None,
504
+ return_dict: bool | None = None,
505
+ **kwargs,
506
+ ):
507
+
508
+ if labels is not None:
509
+ return_dict = True
510
+ else:
511
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
512
+
513
+ model_output = self.model(
514
+ input_ids=input_ids,
515
+ attention_mask=attention_mask,
516
+ inputs_embeds=inputs_embeds,
517
+ past_key_values=None,
518
+ use_cache=False,
519
+ output_hidden_states=output_hidden_states
520
+ )
521
+
522
+ logits = self.lm_head(model_output[0])
523
+
524
+ loss = None
525
+ if labels is not None:
526
+
527
+ loss = F.cross_entropy(
528
+ logits.view(-1, logits.size(-1)),
529
+ labels.view(-1),
530
+ ignore_index=self.config.pad_token_id if self.config.pad_token_id is not None else -100
531
+ )
532
+
533
+ if not return_dict:
534
+ output = (logits,) + model_output[1:]
535
+ return ((loss,) + output) if loss is not None else output
536
+
537
+ return MaskedLMOutput(
538
+ loss=loss,
539
+ logits=logits,
540
+ hidden_states=model_output.hidden_states
541
+ )
model.safetensors.index.json ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 1311480576,
4
+ "total_size": 5245922304
5
+ },
6
+ "weight_map": {
7
+ "model.embedding.weight": "model-00001-of-00002.safetensors",
8
+ "model.predictive_blocks.0.attn.k_proj.weight": "model-00001-of-00002.safetensors",
9
+ "model.predictive_blocks.0.attn.o_proj.bias": "model-00001-of-00002.safetensors",
10
+ "model.predictive_blocks.0.attn.o_proj.weight": "model-00001-of-00002.safetensors",
11
+ "model.predictive_blocks.0.attn.q_proj.weight": "model-00001-of-00002.safetensors",
12
+ "model.predictive_blocks.0.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
13
+ "model.predictive_blocks.0.attn.v_proj.bias": "model-00001-of-00002.safetensors",
14
+ "model.predictive_blocks.0.attn.v_proj.weight": "model-00001-of-00002.safetensors",
15
+ "model.predictive_blocks.0.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
16
+ "model.predictive_blocks.0.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
17
+ "model.predictive_blocks.0.norm_attn_v.bias": "model-00001-of-00002.safetensors",
18
+ "model.predictive_blocks.0.norm_attn_v.weight": "model-00001-of-00002.safetensors",
19
+ "model.predictive_blocks.0.norm_mlp.bias": "model-00001-of-00002.safetensors",
20
+ "model.predictive_blocks.0.norm_mlp.weight": "model-00001-of-00002.safetensors",
21
+ "model.predictive_blocks.0.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
22
+ "model.predictive_blocks.0.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
23
+ "model.predictive_blocks.0.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
24
+ "model.predictive_blocks.0.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
25
+ "model.predictive_blocks.1.attn.k_proj.weight": "model-00001-of-00002.safetensors",
26
+ "model.predictive_blocks.1.attn.o_proj.bias": "model-00001-of-00002.safetensors",
27
+ "model.predictive_blocks.1.attn.o_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.predictive_blocks.1.attn.q_proj.weight": "model-00001-of-00002.safetensors",
29
+ "model.predictive_blocks.1.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
30
+ "model.predictive_blocks.1.attn.v_proj.bias": "model-00001-of-00002.safetensors",
31
+ "model.predictive_blocks.1.attn.v_proj.weight": "model-00001-of-00002.safetensors",
32
+ "model.predictive_blocks.1.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
33
+ "model.predictive_blocks.1.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
34
+ "model.predictive_blocks.1.norm_attn_v.bias": "model-00001-of-00002.safetensors",
35
+ "model.predictive_blocks.1.norm_attn_v.weight": "model-00001-of-00002.safetensors",
36
+ "model.predictive_blocks.1.norm_mlp.bias": "model-00001-of-00002.safetensors",
37
+ "model.predictive_blocks.1.norm_mlp.weight": "model-00001-of-00002.safetensors",
38
+ "model.predictive_blocks.1.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
39
+ "model.predictive_blocks.1.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
40
+ "model.predictive_blocks.1.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
41
+ "model.predictive_blocks.1.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
42
+ "model.predictive_blocks.10.attn.k_proj.weight": "model-00001-of-00002.safetensors",
43
+ "model.predictive_blocks.10.attn.o_proj.bias": "model-00001-of-00002.safetensors",
44
+ "model.predictive_blocks.10.attn.o_proj.weight": "model-00001-of-00002.safetensors",
45
+ "model.predictive_blocks.10.attn.q_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.predictive_blocks.10.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
47
+ "model.predictive_blocks.10.attn.v_proj.bias": "model-00001-of-00002.safetensors",
48
+ "model.predictive_blocks.10.attn.v_proj.weight": "model-00001-of-00002.safetensors",
49
+ "model.predictive_blocks.10.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
50
+ "model.predictive_blocks.10.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
51
+ "model.predictive_blocks.10.norm_attn_v.bias": "model-00001-of-00002.safetensors",
52
+ "model.predictive_blocks.10.norm_attn_v.weight": "model-00001-of-00002.safetensors",
53
+ "model.predictive_blocks.10.norm_mlp.bias": "model-00001-of-00002.safetensors",
54
+ "model.predictive_blocks.10.norm_mlp.weight": "model-00001-of-00002.safetensors",
55
+ "model.predictive_blocks.10.mlp.fc_down.bias": "model-00002-of-00002.safetensors",
56
+ "model.predictive_blocks.10.mlp.fc_down.weight": "model-00002-of-00002.safetensors",
57
+ "model.predictive_blocks.10.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
58
+ "model.predictive_blocks.10.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
59
+ "model.predictive_blocks.11.attn.k_proj.weight": "model-00002-of-00002.safetensors",
60
+ "model.predictive_blocks.11.attn.o_proj.bias": "model-00002-of-00002.safetensors",
61
+ "model.predictive_blocks.11.attn.o_proj.weight": "model-00002-of-00002.safetensors",
62
+ "model.predictive_blocks.11.attn.q_proj.weight": "model-00002-of-00002.safetensors",
63
+ "model.predictive_blocks.11.attn.rotary_emb.freqs": "model-00002-of-00002.safetensors",
64
+ "model.predictive_blocks.11.attn.v_proj.bias": "model-00002-of-00002.safetensors",
65
+ "model.predictive_blocks.11.attn.v_proj.weight": "model-00002-of-00002.safetensors",
66
+ "model.predictive_blocks.11.norm_attn_qk.bias": "model-00002-of-00002.safetensors",
67
+ "model.predictive_blocks.11.norm_attn_qk.weight": "model-00002-of-00002.safetensors",
68
+ "model.predictive_blocks.11.norm_attn_v.bias": "model-00002-of-00002.safetensors",
69
+ "model.predictive_blocks.11.norm_attn_v.weight": "model-00002-of-00002.safetensors",
70
+ "model.predictive_blocks.11.norm_mlp.bias": "model-00002-of-00002.safetensors",
71
+ "model.predictive_blocks.11.norm_mlp.weight": "model-00002-of-00002.safetensors",
72
+ "model.predictive_blocks.11.mlp.fc_down.bias": "model-00002-of-00002.safetensors",
73
+ "model.predictive_blocks.11.mlp.fc_down.weight": "model-00002-of-00002.safetensors",
74
+ "model.predictive_blocks.11.mlp.fc_up.bias": "model-00002-of-00002.safetensors",
75
+ "model.predictive_blocks.11.mlp.fc_up.weight": "model-00002-of-00002.safetensors",
76
+ "model.predictive_blocks.2.attn.k_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.predictive_blocks.2.attn.o_proj.bias": "model-00001-of-00002.safetensors",
78
+ "model.predictive_blocks.2.attn.o_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.predictive_blocks.2.attn.q_proj.weight": "model-00001-of-00002.safetensors",
80
+ "model.predictive_blocks.2.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
81
+ "model.predictive_blocks.2.attn.v_proj.bias": "model-00001-of-00002.safetensors",
82
+ "model.predictive_blocks.2.attn.v_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.predictive_blocks.2.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
84
+ "model.predictive_blocks.2.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
85
+ "model.predictive_blocks.2.norm_attn_v.bias": "model-00001-of-00002.safetensors",
86
+ "model.predictive_blocks.2.norm_attn_v.weight": "model-00001-of-00002.safetensors",
87
+ "model.predictive_blocks.2.norm_mlp.bias": "model-00001-of-00002.safetensors",
88
+ "model.predictive_blocks.2.norm_mlp.weight": "model-00001-of-00002.safetensors",
89
+ "model.predictive_blocks.2.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
90
+ "model.predictive_blocks.2.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
91
+ "model.predictive_blocks.2.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
92
+ "model.predictive_blocks.2.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
93
+ "model.predictive_blocks.3.attn.k_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.predictive_blocks.3.attn.o_proj.bias": "model-00001-of-00002.safetensors",
95
+ "model.predictive_blocks.3.attn.o_proj.weight": "model-00001-of-00002.safetensors",
96
+ "model.predictive_blocks.3.attn.q_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.predictive_blocks.3.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
98
+ "model.predictive_blocks.3.attn.v_proj.bias": "model-00001-of-00002.safetensors",
99
+ "model.predictive_blocks.3.attn.v_proj.weight": "model-00001-of-00002.safetensors",
100
+ "model.predictive_blocks.3.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
101
+ "model.predictive_blocks.3.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
102
+ "model.predictive_blocks.3.norm_attn_v.bias": "model-00001-of-00002.safetensors",
103
+ "model.predictive_blocks.3.norm_attn_v.weight": "model-00001-of-00002.safetensors",
104
+ "model.predictive_blocks.3.norm_mlp.bias": "model-00001-of-00002.safetensors",
105
+ "model.predictive_blocks.3.norm_mlp.weight": "model-00001-of-00002.safetensors",
106
+ "model.predictive_blocks.3.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
107
+ "model.predictive_blocks.3.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
108
+ "model.predictive_blocks.3.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
109
+ "model.predictive_blocks.3.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
110
+ "model.predictive_blocks.4.attn.k_proj.weight": "model-00001-of-00002.safetensors",
111
+ "model.predictive_blocks.4.attn.o_proj.bias": "model-00001-of-00002.safetensors",
112
+ "model.predictive_blocks.4.attn.o_proj.weight": "model-00001-of-00002.safetensors",
113
+ "model.predictive_blocks.4.attn.q_proj.weight": "model-00001-of-00002.safetensors",
114
+ "model.predictive_blocks.4.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
115
+ "model.predictive_blocks.4.attn.v_proj.bias": "model-00001-of-00002.safetensors",
116
+ "model.predictive_blocks.4.attn.v_proj.weight": "model-00001-of-00002.safetensors",
117
+ "model.predictive_blocks.4.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
118
+ "model.predictive_blocks.4.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
119
+ "model.predictive_blocks.4.norm_attn_v.bias": "model-00001-of-00002.safetensors",
120
+ "model.predictive_blocks.4.norm_attn_v.weight": "model-00001-of-00002.safetensors",
121
+ "model.predictive_blocks.4.norm_mlp.bias": "model-00001-of-00002.safetensors",
122
+ "model.predictive_blocks.4.norm_mlp.weight": "model-00001-of-00002.safetensors",
123
+ "model.predictive_blocks.4.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
124
+ "model.predictive_blocks.4.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
125
+ "model.predictive_blocks.4.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
126
+ "model.predictive_blocks.4.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
127
+ "model.predictive_blocks.5.attn.k_proj.weight": "model-00001-of-00002.safetensors",
128
+ "model.predictive_blocks.5.attn.o_proj.bias": "model-00001-of-00002.safetensors",
129
+ "model.predictive_blocks.5.attn.o_proj.weight": "model-00001-of-00002.safetensors",
130
+ "model.predictive_blocks.5.attn.q_proj.weight": "model-00001-of-00002.safetensors",
131
+ "model.predictive_blocks.5.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
132
+ "model.predictive_blocks.5.attn.v_proj.bias": "model-00001-of-00002.safetensors",
133
+ "model.predictive_blocks.5.attn.v_proj.weight": "model-00001-of-00002.safetensors",
134
+ "model.predictive_blocks.5.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
135
+ "model.predictive_blocks.5.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
136
+ "model.predictive_blocks.5.norm_attn_v.bias": "model-00001-of-00002.safetensors",
137
+ "model.predictive_blocks.5.norm_attn_v.weight": "model-00001-of-00002.safetensors",
138
+ "model.predictive_blocks.5.norm_mlp.bias": "model-00001-of-00002.safetensors",
139
+ "model.predictive_blocks.5.norm_mlp.weight": "model-00001-of-00002.safetensors",
140
+ "model.predictive_blocks.5.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
141
+ "model.predictive_blocks.5.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
142
+ "model.predictive_blocks.5.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
143
+ "model.predictive_blocks.5.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
144
+ "model.predictive_blocks.6.attn.k_proj.weight": "model-00001-of-00002.safetensors",
145
+ "model.predictive_blocks.6.attn.o_proj.bias": "model-00001-of-00002.safetensors",
146
+ "model.predictive_blocks.6.attn.o_proj.weight": "model-00001-of-00002.safetensors",
147
+ "model.predictive_blocks.6.attn.q_proj.weight": "model-00001-of-00002.safetensors",
148
+ "model.predictive_blocks.6.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
149
+ "model.predictive_blocks.6.attn.v_proj.bias": "model-00001-of-00002.safetensors",
150
+ "model.predictive_blocks.6.attn.v_proj.weight": "model-00001-of-00002.safetensors",
151
+ "model.predictive_blocks.6.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
152
+ "model.predictive_blocks.6.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
153
+ "model.predictive_blocks.6.norm_attn_v.bias": "model-00001-of-00002.safetensors",
154
+ "model.predictive_blocks.6.norm_attn_v.weight": "model-00001-of-00002.safetensors",
155
+ "model.predictive_blocks.6.norm_mlp.bias": "model-00001-of-00002.safetensors",
156
+ "model.predictive_blocks.6.norm_mlp.weight": "model-00001-of-00002.safetensors",
157
+ "model.predictive_blocks.6.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
158
+ "model.predictive_blocks.6.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
159
+ "model.predictive_blocks.6.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
160
+ "model.predictive_blocks.6.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
161
+ "model.predictive_blocks.7.attn.k_proj.weight": "model-00001-of-00002.safetensors",
162
+ "model.predictive_blocks.7.attn.o_proj.bias": "model-00001-of-00002.safetensors",
163
+ "model.predictive_blocks.7.attn.o_proj.weight": "model-00001-of-00002.safetensors",
164
+ "model.predictive_blocks.7.attn.q_proj.weight": "model-00001-of-00002.safetensors",
165
+ "model.predictive_blocks.7.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
166
+ "model.predictive_blocks.7.attn.v_proj.bias": "model-00001-of-00002.safetensors",
167
+ "model.predictive_blocks.7.attn.v_proj.weight": "model-00001-of-00002.safetensors",
168
+ "model.predictive_blocks.7.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
169
+ "model.predictive_blocks.7.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
170
+ "model.predictive_blocks.7.norm_attn_v.bias": "model-00001-of-00002.safetensors",
171
+ "model.predictive_blocks.7.norm_attn_v.weight": "model-00001-of-00002.safetensors",
172
+ "model.predictive_blocks.7.norm_mlp.bias": "model-00001-of-00002.safetensors",
173
+ "model.predictive_blocks.7.norm_mlp.weight": "model-00001-of-00002.safetensors",
174
+ "model.predictive_blocks.7.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
175
+ "model.predictive_blocks.7.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
176
+ "model.predictive_blocks.7.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
177
+ "model.predictive_blocks.7.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
178
+ "model.predictive_blocks.8.attn.k_proj.weight": "model-00001-of-00002.safetensors",
179
+ "model.predictive_blocks.8.attn.o_proj.bias": "model-00001-of-00002.safetensors",
180
+ "model.predictive_blocks.8.attn.o_proj.weight": "model-00001-of-00002.safetensors",
181
+ "model.predictive_blocks.8.attn.q_proj.weight": "model-00001-of-00002.safetensors",
182
+ "model.predictive_blocks.8.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
183
+ "model.predictive_blocks.8.attn.v_proj.bias": "model-00001-of-00002.safetensors",
184
+ "model.predictive_blocks.8.attn.v_proj.weight": "model-00001-of-00002.safetensors",
185
+ "model.predictive_blocks.8.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
186
+ "model.predictive_blocks.8.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
187
+ "model.predictive_blocks.8.norm_attn_v.bias": "model-00001-of-00002.safetensors",
188
+ "model.predictive_blocks.8.norm_attn_v.weight": "model-00001-of-00002.safetensors",
189
+ "model.predictive_blocks.8.norm_mlp.bias": "model-00001-of-00002.safetensors",
190
+ "model.predictive_blocks.8.norm_mlp.weight": "model-00001-of-00002.safetensors",
191
+ "model.predictive_blocks.8.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
192
+ "model.predictive_blocks.8.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
193
+ "model.predictive_blocks.8.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
194
+ "model.predictive_blocks.8.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
195
+ "model.predictive_blocks.9.attn.k_proj.weight": "model-00001-of-00002.safetensors",
196
+ "model.predictive_blocks.9.attn.o_proj.bias": "model-00001-of-00002.safetensors",
197
+ "model.predictive_blocks.9.attn.o_proj.weight": "model-00001-of-00002.safetensors",
198
+ "model.predictive_blocks.9.attn.q_proj.weight": "model-00001-of-00002.safetensors",
199
+ "model.predictive_blocks.9.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
200
+ "model.predictive_blocks.9.attn.v_proj.bias": "model-00001-of-00002.safetensors",
201
+ "model.predictive_blocks.9.attn.v_proj.weight": "model-00001-of-00002.safetensors",
202
+ "model.predictive_blocks.9.norm_attn_qk.bias": "model-00001-of-00002.safetensors",
203
+ "model.predictive_blocks.9.norm_attn_qk.weight": "model-00001-of-00002.safetensors",
204
+ "model.predictive_blocks.9.norm_attn_v.bias": "model-00001-of-00002.safetensors",
205
+ "model.predictive_blocks.9.norm_attn_v.weight": "model-00001-of-00002.safetensors",
206
+ "model.predictive_blocks.9.norm_mlp.bias": "model-00001-of-00002.safetensors",
207
+ "model.predictive_blocks.9.norm_mlp.weight": "model-00001-of-00002.safetensors",
208
+ "model.predictive_blocks.9.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
209
+ "model.predictive_blocks.9.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
210
+ "model.predictive_blocks.9.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
211
+ "model.predictive_blocks.9.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
212
+ "model.norm_out.bias": "model-00002-of-00002.safetensors",
213
+ "model.norm_out.weight": "model-00002-of-00002.safetensors",
214
+ "model.feature_blocks.0.attn.k_proj.weight": "model-00001-of-00002.safetensors",
215
+ "model.feature_blocks.0.attn.o_proj.bias": "model-00001-of-00002.safetensors",
216
+ "model.feature_blocks.0.attn.o_proj.weight": "model-00001-of-00002.safetensors",
217
+ "model.feature_blocks.0.attn.q_proj.weight": "model-00001-of-00002.safetensors",
218
+ "model.feature_blocks.0.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
219
+ "model.feature_blocks.0.attn.v_proj.bias": "model-00001-of-00002.safetensors",
220
+ "model.feature_blocks.0.attn.v_proj.weight": "model-00001-of-00002.safetensors",
221
+ "model.feature_blocks.0.norm_attn.bias": "model-00001-of-00002.safetensors",
222
+ "model.feature_blocks.0.norm_attn.weight": "model-00001-of-00002.safetensors",
223
+ "model.feature_blocks.0.norm_mlp.bias": "model-00001-of-00002.safetensors",
224
+ "model.feature_blocks.0.norm_mlp.weight": "model-00001-of-00002.safetensors",
225
+ "model.feature_blocks.0.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
226
+ "model.feature_blocks.0.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
227
+ "model.feature_blocks.0.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
228
+ "model.feature_blocks.0.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
229
+ "model.feature_blocks.1.attn.k_proj.weight": "model-00001-of-00002.safetensors",
230
+ "model.feature_blocks.1.attn.o_proj.bias": "model-00001-of-00002.safetensors",
231
+ "model.feature_blocks.1.attn.o_proj.weight": "model-00001-of-00002.safetensors",
232
+ "model.feature_blocks.1.attn.q_proj.weight": "model-00001-of-00002.safetensors",
233
+ "model.feature_blocks.1.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
234
+ "model.feature_blocks.1.attn.v_proj.bias": "model-00001-of-00002.safetensors",
235
+ "model.feature_blocks.1.attn.v_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.feature_blocks.1.norm_attn.bias": "model-00001-of-00002.safetensors",
237
+ "model.feature_blocks.1.norm_attn.weight": "model-00001-of-00002.safetensors",
238
+ "model.feature_blocks.1.norm_mlp.bias": "model-00001-of-00002.safetensors",
239
+ "model.feature_blocks.1.norm_mlp.weight": "model-00001-of-00002.safetensors",
240
+ "model.feature_blocks.1.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
241
+ "model.feature_blocks.1.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
242
+ "model.feature_blocks.1.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
243
+ "model.feature_blocks.1.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
244
+ "model.feature_blocks.10.attn.k_proj.weight": "model-00001-of-00002.safetensors",
245
+ "model.feature_blocks.10.attn.o_proj.bias": "model-00001-of-00002.safetensors",
246
+ "model.feature_blocks.10.attn.o_proj.weight": "model-00001-of-00002.safetensors",
247
+ "model.feature_blocks.10.attn.q_proj.weight": "model-00001-of-00002.safetensors",
248
+ "model.feature_blocks.10.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
249
+ "model.feature_blocks.10.attn.v_proj.bias": "model-00001-of-00002.safetensors",
250
+ "model.feature_blocks.10.attn.v_proj.weight": "model-00001-of-00002.safetensors",
251
+ "model.feature_blocks.10.norm_attn.bias": "model-00001-of-00002.safetensors",
252
+ "model.feature_blocks.10.norm_attn.weight": "model-00001-of-00002.safetensors",
253
+ "model.feature_blocks.10.norm_mlp.bias": "model-00001-of-00002.safetensors",
254
+ "model.feature_blocks.10.norm_mlp.weight": "model-00001-of-00002.safetensors",
255
+ "model.feature_blocks.10.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
256
+ "model.feature_blocks.10.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
257
+ "model.feature_blocks.10.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
258
+ "model.feature_blocks.10.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
259
+ "model.feature_blocks.11.attn.k_proj.weight": "model-00001-of-00002.safetensors",
260
+ "model.feature_blocks.11.attn.o_proj.bias": "model-00001-of-00002.safetensors",
261
+ "model.feature_blocks.11.attn.o_proj.weight": "model-00001-of-00002.safetensors",
262
+ "model.feature_blocks.11.attn.q_proj.weight": "model-00001-of-00002.safetensors",
263
+ "model.feature_blocks.11.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
264
+ "model.feature_blocks.11.attn.v_proj.bias": "model-00001-of-00002.safetensors",
265
+ "model.feature_blocks.11.attn.v_proj.weight": "model-00001-of-00002.safetensors",
266
+ "model.feature_blocks.11.norm_attn.bias": "model-00001-of-00002.safetensors",
267
+ "model.feature_blocks.11.norm_attn.weight": "model-00001-of-00002.safetensors",
268
+ "model.feature_blocks.11.norm_mlp.bias": "model-00001-of-00002.safetensors",
269
+ "model.feature_blocks.11.norm_mlp.weight": "model-00001-of-00002.safetensors",
270
+ "model.feature_blocks.11.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
271
+ "model.feature_blocks.11.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
272
+ "model.feature_blocks.11.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
273
+ "model.feature_blocks.11.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
274
+ "model.feature_blocks.2.attn.k_proj.weight": "model-00001-of-00002.safetensors",
275
+ "model.feature_blocks.2.attn.o_proj.bias": "model-00001-of-00002.safetensors",
276
+ "model.feature_blocks.2.attn.o_proj.weight": "model-00001-of-00002.safetensors",
277
+ "model.feature_blocks.2.attn.q_proj.weight": "model-00001-of-00002.safetensors",
278
+ "model.feature_blocks.2.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
279
+ "model.feature_blocks.2.attn.v_proj.bias": "model-00001-of-00002.safetensors",
280
+ "model.feature_blocks.2.attn.v_proj.weight": "model-00001-of-00002.safetensors",
281
+ "model.feature_blocks.2.norm_attn.bias": "model-00001-of-00002.safetensors",
282
+ "model.feature_blocks.2.norm_attn.weight": "model-00001-of-00002.safetensors",
283
+ "model.feature_blocks.2.norm_mlp.bias": "model-00001-of-00002.safetensors",
284
+ "model.feature_blocks.2.norm_mlp.weight": "model-00001-of-00002.safetensors",
285
+ "model.feature_blocks.2.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
286
+ "model.feature_blocks.2.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
287
+ "model.feature_blocks.2.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
288
+ "model.feature_blocks.2.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
289
+ "model.feature_blocks.3.attn.k_proj.weight": "model-00001-of-00002.safetensors",
290
+ "model.feature_blocks.3.attn.o_proj.bias": "model-00001-of-00002.safetensors",
291
+ "model.feature_blocks.3.attn.o_proj.weight": "model-00001-of-00002.safetensors",
292
+ "model.feature_blocks.3.attn.q_proj.weight": "model-00001-of-00002.safetensors",
293
+ "model.feature_blocks.3.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
294
+ "model.feature_blocks.3.attn.v_proj.bias": "model-00001-of-00002.safetensors",
295
+ "model.feature_blocks.3.attn.v_proj.weight": "model-00001-of-00002.safetensors",
296
+ "model.feature_blocks.3.norm_attn.bias": "model-00001-of-00002.safetensors",
297
+ "model.feature_blocks.3.norm_attn.weight": "model-00001-of-00002.safetensors",
298
+ "model.feature_blocks.3.norm_mlp.bias": "model-00001-of-00002.safetensors",
299
+ "model.feature_blocks.3.norm_mlp.weight": "model-00001-of-00002.safetensors",
300
+ "model.feature_blocks.3.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
301
+ "model.feature_blocks.3.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
302
+ "model.feature_blocks.3.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
303
+ "model.feature_blocks.3.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
304
+ "model.feature_blocks.4.attn.k_proj.weight": "model-00001-of-00002.safetensors",
305
+ "model.feature_blocks.4.attn.o_proj.bias": "model-00001-of-00002.safetensors",
306
+ "model.feature_blocks.4.attn.o_proj.weight": "model-00001-of-00002.safetensors",
307
+ "model.feature_blocks.4.attn.q_proj.weight": "model-00001-of-00002.safetensors",
308
+ "model.feature_blocks.4.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
309
+ "model.feature_blocks.4.attn.v_proj.bias": "model-00001-of-00002.safetensors",
310
+ "model.feature_blocks.4.attn.v_proj.weight": "model-00001-of-00002.safetensors",
311
+ "model.feature_blocks.4.norm_attn.bias": "model-00001-of-00002.safetensors",
312
+ "model.feature_blocks.4.norm_attn.weight": "model-00001-of-00002.safetensors",
313
+ "model.feature_blocks.4.norm_mlp.bias": "model-00001-of-00002.safetensors",
314
+ "model.feature_blocks.4.norm_mlp.weight": "model-00001-of-00002.safetensors",
315
+ "model.feature_blocks.4.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
316
+ "model.feature_blocks.4.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
317
+ "model.feature_blocks.4.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
318
+ "model.feature_blocks.4.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
319
+ "model.feature_blocks.5.attn.k_proj.weight": "model-00001-of-00002.safetensors",
320
+ "model.feature_blocks.5.attn.o_proj.bias": "model-00001-of-00002.safetensors",
321
+ "model.feature_blocks.5.attn.o_proj.weight": "model-00001-of-00002.safetensors",
322
+ "model.feature_blocks.5.attn.q_proj.weight": "model-00001-of-00002.safetensors",
323
+ "model.feature_blocks.5.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
324
+ "model.feature_blocks.5.attn.v_proj.bias": "model-00001-of-00002.safetensors",
325
+ "model.feature_blocks.5.attn.v_proj.weight": "model-00001-of-00002.safetensors",
326
+ "model.feature_blocks.5.norm_attn.bias": "model-00001-of-00002.safetensors",
327
+ "model.feature_blocks.5.norm_attn.weight": "model-00001-of-00002.safetensors",
328
+ "model.feature_blocks.5.norm_mlp.bias": "model-00001-of-00002.safetensors",
329
+ "model.feature_blocks.5.norm_mlp.weight": "model-00001-of-00002.safetensors",
330
+ "model.feature_blocks.5.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
331
+ "model.feature_blocks.5.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
332
+ "model.feature_blocks.5.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
333
+ "model.feature_blocks.5.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
334
+ "model.feature_blocks.6.attn.k_proj.weight": "model-00001-of-00002.safetensors",
335
+ "model.feature_blocks.6.attn.o_proj.bias": "model-00001-of-00002.safetensors",
336
+ "model.feature_blocks.6.attn.o_proj.weight": "model-00001-of-00002.safetensors",
337
+ "model.feature_blocks.6.attn.q_proj.weight": "model-00001-of-00002.safetensors",
338
+ "model.feature_blocks.6.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
339
+ "model.feature_blocks.6.attn.v_proj.bias": "model-00001-of-00002.safetensors",
340
+ "model.feature_blocks.6.attn.v_proj.weight": "model-00001-of-00002.safetensors",
341
+ "model.feature_blocks.6.norm_attn.bias": "model-00001-of-00002.safetensors",
342
+ "model.feature_blocks.6.norm_attn.weight": "model-00001-of-00002.safetensors",
343
+ "model.feature_blocks.6.norm_mlp.bias": "model-00001-of-00002.safetensors",
344
+ "model.feature_blocks.6.norm_mlp.weight": "model-00001-of-00002.safetensors",
345
+ "model.feature_blocks.6.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
346
+ "model.feature_blocks.6.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
347
+ "model.feature_blocks.6.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
348
+ "model.feature_blocks.6.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
349
+ "model.feature_blocks.7.attn.k_proj.weight": "model-00001-of-00002.safetensors",
350
+ "model.feature_blocks.7.attn.o_proj.bias": "model-00001-of-00002.safetensors",
351
+ "model.feature_blocks.7.attn.o_proj.weight": "model-00001-of-00002.safetensors",
352
+ "model.feature_blocks.7.attn.q_proj.weight": "model-00001-of-00002.safetensors",
353
+ "model.feature_blocks.7.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
354
+ "model.feature_blocks.7.attn.v_proj.bias": "model-00001-of-00002.safetensors",
355
+ "model.feature_blocks.7.attn.v_proj.weight": "model-00001-of-00002.safetensors",
356
+ "model.feature_blocks.7.norm_attn.bias": "model-00001-of-00002.safetensors",
357
+ "model.feature_blocks.7.norm_attn.weight": "model-00001-of-00002.safetensors",
358
+ "model.feature_blocks.7.norm_mlp.bias": "model-00001-of-00002.safetensors",
359
+ "model.feature_blocks.7.norm_mlp.weight": "model-00001-of-00002.safetensors",
360
+ "model.feature_blocks.7.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
361
+ "model.feature_blocks.7.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
362
+ "model.feature_blocks.7.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
363
+ "model.feature_blocks.7.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
364
+ "model.feature_blocks.8.attn.k_proj.weight": "model-00001-of-00002.safetensors",
365
+ "model.feature_blocks.8.attn.o_proj.bias": "model-00001-of-00002.safetensors",
366
+ "model.feature_blocks.8.attn.o_proj.weight": "model-00001-of-00002.safetensors",
367
+ "model.feature_blocks.8.attn.q_proj.weight": "model-00001-of-00002.safetensors",
368
+ "model.feature_blocks.8.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
369
+ "model.feature_blocks.8.attn.v_proj.bias": "model-00001-of-00002.safetensors",
370
+ "model.feature_blocks.8.attn.v_proj.weight": "model-00001-of-00002.safetensors",
371
+ "model.feature_blocks.8.norm_attn.bias": "model-00001-of-00002.safetensors",
372
+ "model.feature_blocks.8.norm_attn.weight": "model-00001-of-00002.safetensors",
373
+ "model.feature_blocks.8.norm_mlp.bias": "model-00001-of-00002.safetensors",
374
+ "model.feature_blocks.8.norm_mlp.weight": "model-00001-of-00002.safetensors",
375
+ "model.feature_blocks.8.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
376
+ "model.feature_blocks.8.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
377
+ "model.feature_blocks.8.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
378
+ "model.feature_blocks.8.mlp.fc_up.weight": "model-00001-of-00002.safetensors",
379
+ "model.feature_blocks.9.attn.k_proj.weight": "model-00001-of-00002.safetensors",
380
+ "model.feature_blocks.9.attn.o_proj.bias": "model-00001-of-00002.safetensors",
381
+ "model.feature_blocks.9.attn.o_proj.weight": "model-00001-of-00002.safetensors",
382
+ "model.feature_blocks.9.attn.q_proj.weight": "model-00001-of-00002.safetensors",
383
+ "model.feature_blocks.9.attn.rotary_emb.freqs": "model-00001-of-00002.safetensors",
384
+ "model.feature_blocks.9.attn.v_proj.bias": "model-00001-of-00002.safetensors",
385
+ "model.feature_blocks.9.attn.v_proj.weight": "model-00001-of-00002.safetensors",
386
+ "model.feature_blocks.9.norm_attn.bias": "model-00001-of-00002.safetensors",
387
+ "model.feature_blocks.9.norm_attn.weight": "model-00001-of-00002.safetensors",
388
+ "model.feature_blocks.9.norm_mlp.bias": "model-00001-of-00002.safetensors",
389
+ "model.feature_blocks.9.norm_mlp.weight": "model-00001-of-00002.safetensors",
390
+ "model.feature_blocks.9.mlp.fc_down.bias": "model-00001-of-00002.safetensors",
391
+ "model.feature_blocks.9.mlp.fc_down.weight": "model-00001-of-00002.safetensors",
392
+ "model.feature_blocks.9.mlp.fc_up.bias": "model-00001-of-00002.safetensors",
393
+ "model.feature_blocks.9.mlp.fc_up.weight": "model-00001-of-00002.safetensors"
394
+ }
395
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "unk_token": "<|endoftext|>"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "extra_special_tokens": {},
17
+ "model_max_length": 1024,
18
+ "pad_token": "<|endoftext|>",
19
+ "tokenizer_class": "GPT2Tokenizer",
20
+ "unk_token": "<|endoftext|>"
21
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff