merterbak committed on
Commit
c4c5b1d
·
verified ·
1 Parent(s): 330c5ea

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SeedForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_seed.SeedConfig",
7
+ "AutoModelForCausalLM": "modeling_seed.SeedForCausalLM"
8
+ },
9
+ "model_type": "seed",
10
+ "vocab_size": 64000,
11
+ "n_embd": 1024,
12
+ "n_layer": 28,
13
+ "n_head": 16,
14
+ "n_kv_head": 8,
15
+ "head_dim": 128,
16
+ "intermediate_size": 3072,
17
+ "block_size": 4096,
18
+ "bias": false,
19
+ "dropout": 0.0,
20
+ "rms_norm_eps": 1e-6,
21
+ "rope_theta": 1000000.0,
22
+ "rope_scaling_type": "none",
23
+ "rope_scaling_factor": 1.0,
24
+ "tie_word_embeddings": true,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.57.3"
27
+ }
configuration_seed.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class SeedConfig(PretrainedConfig):
    """Configuration for the Seed causal language model.

    Args:
        vocab_size: Size of the token vocabulary.
        n_embd: Hidden (embedding) dimension.
        n_layer: Number of decoder layers.
        n_head: Number of attention (query) heads.
        n_kv_head: Number of key/value heads for grouped-query attention.
        head_dim: Per-head dimension.
        intermediate_size: Inner dimension of the SwiGLU feed-forward block.
        block_size: Maximum sequence length (context window).
        bias: Whether feed-forward linear layers use a bias term.
        dropout: Dropout probability (stored on the config; usage is up to the model).
        rope_theta: Base frequency for rotary position embeddings.
        rope_scaling_type: RoPE scaling strategy ("none", "yarn" or "ntk").
        rope_scaling_factor: Context-extension factor used by RoPE scaling.
        rms_norm_eps: Epsilon used by RMSNorm layers.
        tie_word_embeddings: Whether input and output embeddings share weights.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "seed"

    def __init__(
        self,
        vocab_size: int = 64000,
        n_embd: int = 1024,
        n_layer: int = 28,
        n_head: int = 16,
        n_kv_head: int = 8,
        head_dim: int = 128,
        intermediate_size: int = 3072,
        block_size: int = 4096,
        bias: bool = False,
        dropout: float = 0.0,
        rope_theta: float = 1000000.0,
        rope_scaling_type: str = "none",
        rope_scaling_factor: float = 1.0,
        rms_norm_eps: float = 1e-6,
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.block_size = block_size
        self.bias = bias
        self.dropout = dropout
        self.rope_theta = rope_theta
        self.rope_scaling_type = rope_scaling_type
        self.rope_scaling_factor = rope_scaling_factor
        self.rms_norm_eps = rms_norm_eps

        # Transformers compatibility aliases
        # (standard attribute names expected by generic transformers utilities).
        self.hidden_size = n_embd
        self.num_hidden_layers = n_layer
        self.num_attention_heads = n_head
        self.num_key_value_heads = n_kv_head

        # Initialize the base class last, passing through remaining kwargs.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 2,
3
+ "eos_token_id": 1,
4
+ "pad_token_id": 1,
5
+ "unk_token_id": 0,
6
+ "do_sample": true,
7
+ "temperature": 1.0,
8
+ "top_k": 20,
9
+ "top_p": 0.95,
10
+ "max_new_tokens": 256,
11
+ "transformers_version": "4.57.3"
12
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fb2e33c84f461a755662e675e3ac45531b0c07a0a77cc38b12f78d37cb451eb
3
+ size 2024045696
modeling_seed.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import PreTrainedModel
6
+ from transformers.cache_utils import DynamicCache
7
+ from transformers.generation import GenerationMixin
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from .configuration_seed import SeedConfig
10
+
11
+
12
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Normalizes over the last dimension and applies a learned per-channel gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.epsilon = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute the variance/rsqrt in float32 for numerical stability when
        # running in fp16/bf16, then cast back (standard HF RMSNorm behavior).
        # For float32 inputs this is identical to computing in-place.
        input_dtype = x.dtype
        x = x.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)
        return x.to(input_dtype) * self.weight
22
+
23
+
24
class RoPEEmbedding(nn.Module):
    """Rotary position embedding with optional yarn/ntk base rescaling.

    Precomputes inverse frequencies from the (possibly rescaled) base and, at
    forward time, returns the cos/sin tables for the given positions.
    """

    def __init__(self, config, device=None):
        super().__init__()
        self.config = config
        assert config.n_embd % config.n_head == 0
        self.head_dim = config.head_dim
        self.rope_scaling_type = str(getattr(config, "rope_scaling_type", "none"))
        self.rope_scaling_factor = float(getattr(config, "rope_scaling_factor", 1.0))

        theta = float(config.rope_theta)
        self.position_scale = 1.0
        self.attention_scaling = 1.0

        # Scaling only takes effect for a non-"none" type with factor != 1.0.
        scaling_active = (
            self.rope_scaling_type != "none" and self.rope_scaling_factor != 1.0
        )
        if scaling_active:
            if self.rope_scaling_type == "yarn":
                theta = theta * (
                    self.rope_scaling_factor ** (self.head_dim / (self.head_dim - 2.0))
                )
                # yarn additionally rescales attention magnitudes.
                self.attention_scaling = 0.1 * math.log(self.rope_scaling_factor) + 1.0
            elif self.rope_scaling_type == "ntk":
                theta = theta * (
                    self.rope_scaling_factor ** (self.head_dim / (self.head_dim - 2.0))
                )
            else:
                raise ValueError(f"Unknown rope_scaling_type={self.rope_scaling_type!r}")

        self.base = theta

        # inv_freq[i] = base ** (-2i / head_dim), one entry per rotation pair.
        exponents = (
            torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device)
            / float(self.head_dim)
        )
        inv_freq = 1.0 / (self.base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        out_dtype = x.dtype

        scaled_pos = position_ids.float().unsqueeze(-1) * self.position_scale
        freqs = scaled_pos * self.inv_freq.unsqueeze(0).unsqueeze(0)
        # Duplicate so the table covers the full head_dim (two halves).
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = (emb.cos() * self.attention_scaling).to(out_dtype)
        sin = (emb.sin() * self.attention_scaling).to(out_dtype)
        return cos, sin
67
+
68
+
69
def rotate_half(x):
    """RoPE helper: map (..., [a, b]) to (..., [-b, a]) along the last dim."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((-back, front), dim=-1)
72
+
73
+
74
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    `unsqueeze_dim` inserts the broadcast axis (the head axis by default) into
    the cos/sin tables so they align with q/k of shape (B, H, T, D).
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
80
+
81
+
82
class GQA(nn.Module):
    """Grouped-query attention: n_head query heads share n_kv_head key/value heads.

    Uses separate q/k/v/o projections (all bias-free), per-head RMSNorm on
    queries and keys before RoPE, and SDPA for the attention kernel. Supports
    an optional transformers-style KV cache via `past_key_values`.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # Cache layer index; used to address this layer's slot in the KV cache.
        self.layer_idx = int(layer_idx)
        self.n_head = config.n_head
        # Fall back to full multi-head attention if n_kv_head is absent.
        self.n_kv_head = int(getattr(config, "n_kv_head", config.n_head))
        self.n_embd = config.n_embd
        self.block_size = int(config.block_size)
        assert 1 <= self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0

        self.head_dim = config.head_dim
        q_dim = self.n_head * self.head_dim
        kv_dim = self.n_kv_head * self.head_dim

        self.q_proj = nn.Linear(self.n_embd, q_dim, bias=False)
        self.k_proj = nn.Linear(self.n_embd, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.n_embd, kv_dim, bias=False)
        self.o_proj = nn.Linear(q_dim, self.n_embd, bias=False)

        # QK-norm: per-head RMSNorm applied before rotary embedding.
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def forward(self, x, cos, sin, past_key_values=None):
        # x: (B, T, C); cos/sin: RoPE tables for the *current* T positions.
        B, T, C = x.shape
        # Project and reshape to (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        # Normalize per head, then rotate. Order matters: norm before RoPE.
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        past_len = 0
        if past_key_values is not None:
            # Read the cached length *before* update so past_len reflects only
            # previous steps; update() returns keys/values including this step.
            past_len = past_key_values.get_seq_length(self.layer_idx)
            k, v = past_key_values.update(k, v, self.layer_idx)

        if self.n_kv_head != self.n_head:
            # Expand KV heads so each query head has a matching key/value head.
            repeat_factor = self.n_head // self.n_kv_head
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        if past_len == 0:
            # Prefill (no cache history): SDPA's built-in causal mask applies.
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        else:
            # Decoding with history: is_causal=True would misalign because the
            # query rows start at past_len, so build an explicit additive mask
            # allowing each query position to see keys at positions <= its own.
            Tk = int(k.size(2))
            query_pos = past_len + torch.arange(T, device=x.device)
            key_pos = torch.arange(Tk, device=x.device)
            causal_mask = key_pos.unsqueeze(0) <= query_pos.unsqueeze(1)
            attn_mask = torch.zeros((1, 1, T, Tk), device=x.device, dtype=q.dtype)
            attn_mask = attn_mask.masked_fill(~causal_mask.view(1, 1, T, Tk), torch.finfo(q.dtype).min)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

        # Merge heads back to (B, T, n_head * head_dim) and project out.
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.o_proj(y)
        return y
139
+
140
+
141
class SwiGLU(nn.Module):
    """Gated feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        inner = getattr(config, "intermediate_size", None)
        if inner is None:
            # Fall back to the 2/3 * 4d heuristic, rounded up to a multiple of 256.
            inner = int(4 * self.n_embd * 2 / 3)
            inner = (inner + 255) // 256 * 256

        self.gate_proj = nn.Linear(self.n_embd, inner, bias=config.bias)
        self.up_proj = nn.Linear(self.n_embd, inner, bias=config.bias)
        self.down_proj = nn.Linear(inner, self.n_embd, bias=config.bias)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
160
+
161
+
162
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.input_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.attn = GQA(config, layer_idx=layer_idx)
        self.mlp = SwiGLU(config)

    def forward(self, x, cos, sin, past_key_values=None):
        # Attention sub-block (pre-norm + residual).
        x = x + self.attn(self.input_norm(x), cos, sin, past_key_values=past_key_values)
        # Feed-forward sub-block (pre-norm + residual).
        x = x + self.mlp(self.post_attn_norm(x))
        return x
182
+
183
+
184
class SeedPreTrainedModel(PreTrainedModel):
    """Base class wiring the Seed architecture into the transformers API."""

    # Config class resolved by from_pretrained / the auto classes.
    config_class = SeedConfig
    base_model_prefix = "model"
    # Keep each decoder layer whole when device-mapping / sharding.
    _no_split_modules = ["DecoderLayer"]
    # Cache tensors manage their own device placement.
    _skip_keys_device_placement = ["past_key_values"]
    # Attention is implemented with scaled_dot_product_attention.
    _supports_sdpa = True
190
+
191
+
192
class SeedForCausalLM(SeedPreTrainedModel, GenerationMixin):
    """Seed decoder-only transformer with a causal language-modeling head.

    Token embeddings -> n_layer DecoderLayers (shared RoPE tables) -> final
    RMSNorm -> lm_head. Input and output embeddings may be tied via the config.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList(
            [DecoderLayer(config, layer_idx=i) for i in range(config.n_layer)]
        )
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.rope = RoPEEmbedding(config)
        # Handles weight init and embedding tying (tie_word_embeddings).
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        token_type_ids=None,
        **kwargs
    ):
        """Run the decoder and optionally compute the shifted LM loss.

        NOTE(review): `attention_mask` and `token_type_ids` are accepted for
        API compatibility but are not used by the attention layers — left-padded
        batches are not masked out. Confirm callers use unpadded/right-aligned
        inputs.
        """
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        B, T = inputs_embeds.shape[:2]

        # Lazily create a cache when generation requests one.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if position_ids is None:
            # Continue numbering after any cached tokens.
            past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = (
                torch.arange(past_seen, past_seen + T, device=inputs_embeds.device)
                .unsqueeze(0)
                .expand(B, T)
            )

        # RoPE tables are shared across all layers.
        cos, sin = self.rope(inputs_embeds, position_ids)

        x = inputs_embeds
        for layer in self.layers:
            x = layer(x, cos, sin, past_key_values=past_key_values)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1. ignore_index=-100 follows
            # the transformers convention of masking prompt/padding label
            # positions out of the loss.
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble model inputs for one generation step.

        With a non-empty cache only the most recent token is fed; on the first
        step `inputs_embeds` (if given) takes precedence over `input_ids`.
        """
        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
            if past_length > 0:
                # Cached steps already cover earlier tokens.
                input_ids = input_ids[:, -1:]

        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": True,
            "attention_mask": attention_mask,
        })
        return model_inputs
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "PreTrainedTokenizerFast",
3
+ "bos_token": "<|im_start|>",
4
+ "eos_token": "<|endoftext|>",
5
+ "unk_token": "<unk>",
6
+ "pad_token": "<|endoftext|>",
7
+ "add_bos_token": false,
8
+ "add_eos_token": false,
9
+ "model_max_length": 4096,
10
+ "clean_up_tokenization_spaces": false,
11
+ "added_tokens_decoder": {
12
+ "0": {
13
+ "content": "<unk>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "1": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "2": {
29
+ "content": "<|im_start|>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "3": {
37
+ "content": "<|im_end|>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ }
44
+ }
45
+ }