SnifferCaptain commited on
Commit
48bbf86
·
verified ·
1 Parent(s): f8dc27b

Upload 7 files

Browse files
hf1.1s0/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "YForCausalLM1_1"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "ymodel1_1.YConfig1_1",
7
+ "AutoModelForCausalLM": "ymodel1_1.YForCausalLM1_1"
8
+ },
9
+ "bos_token_id": 1,
10
+ "dropout": 0.1,
11
+ "eos_token_id": 2,
12
+ "exp": 3.0,
13
+ "ffn_shared": 3,
14
+ "flash_attn": true,
15
+ "groups": 6,
16
+ "head_dim": 64,
17
+ "hidden_act": "gelu_pytorch_tanh",
18
+ "hidden_size": 512,
19
+ "intermediate_size": null,
20
+ "max_position_embeddings": 4096,
21
+ "model_type": "ynet",
22
+ "num_heads": 12,
23
+ "num_layers": 27,
24
+ "pe_dim": 96,
25
+ "rms_norm_eps": 1e-07,
26
+ "rope_theta": 50000.0,
27
+ "self_distill": true,
28
+ "torch_dtype": "bfloat16",
29
+ "transformers_version": "4.51.3",
30
+ "vocab_size": 6400
31
+ }
hf1.1s0/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.51.3"
6
+ }
hf1.1s0/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2915f75403c5ad05fff856884e57a28f10ab4343ac02b238418bbc874330e6b6
3
+ size 160612347
hf1.1s0/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|im_start|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
hf1.1s0/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
hf1.1s0/tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|im_start|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "<|im_end|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<|im_start|>",
33
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
34
+ "clean_up_tokenization_spaces": false,
35
+ "eos_token": "<|im_end|>",
36
+ "extra_special_tokens": {},
37
+ "legacy": true,
38
+ "model_max_length": 32768,
39
+ "pad_token": "<|endoftext|>",
40
+ "sp_model_kwargs": {},
41
+ "spaces_between_special_tokens": false,
42
+ "tokenizer_class": "PreTrainedTokenizer",
43
+ "unk_token": "<|endoftext|>"
44
+ }
hf1.1s0/ymodel1_1.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Optional, Tuple, Union, List
5
+ from transformers import PreTrainedModel, GenerationMixin
6
+ from transformers.activations import ACT2FN
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from transformers.configuration_utils import PretrainedConfig
9
+
10
+ class YConfig1_1(PretrainedConfig):
11
+ model_type = "ynet"
12
+
13
+ def __init__(
14
+ self,
15
+ dropout: float = 0.1,
16
+ bos_token_id: int = 1,
17
+ eos_token_id: int = 2,
18
+ hidden_act: str = 'gelu_pytorch_tanh',
19
+ exp: float = 3.0,
20
+ ffn_shared: int = 3,
21
+ hidden_size: int = 512,
22
+ intermediate_size: int = None,
23
+ max_position_embeddings: int = 8192,
24
+ num_heads: int = 8,
25
+ num_layers: int = 9,
26
+ pe_dim: int = 64,
27
+ head_dim: int = 64,
28
+ groups: int = 4,
29
+ vocab_size: int = 6400,
30
+ rms_norm_eps: float = 1e-7,
31
+ rope_theta: int = 5e4,
32
+ flash_attn: bool = True,
33
+ self_distill: bool = True,
34
+ **kwargs
35
+ ):
36
+ super().__init__(**kwargs)
37
+ self.dropout = dropout
38
+ self.bos_token_id = bos_token_id
39
+ self.eos_token_id = eos_token_id
40
+ self.hidden_act = hidden_act
41
+ self.exp = exp # ffn 扩张倍率
42
+ self.ffn_shared = ffn_shared # ffn up & down权重共享层数
43
+ self.hidden_size = hidden_size
44
+ self.intermediate_size = intermediate_size
45
+ self.max_position_embeddings = max_position_embeddings
46
+ self.num_heads = num_heads # q头数
47
+ self.num_layers = num_layers # 层数
48
+ self.pe_dim = pe_dim # 位置嵌入头数
49
+ self.head_dim = head_dim # 头维度
50
+ self.groups = groups # GQA每个分组的头数
51
+ self.vocab_size = vocab_size
52
+ self.rms_norm_eps = rms_norm_eps
53
+ self.rope_theta = rope_theta
54
+ self.flash_attn = flash_attn
55
+ self.self_distill = self_distill
56
+
57
+ def scale_lvl(self, lvl:int=0):
58
+ if lvl == 0:
59
+ # normal settings [80.27m]
60
+ self.exp = 3.0
61
+ self.ffn_shared = 3
62
+ self.hidden_size = 512
63
+ self.num_heads = 12
64
+ self.num_layers = 27
65
+ self.pe_dim = 96
66
+ self.head_dim = 64
67
+ self.groups = 6
68
+ elif lvl == -1:
69
+ # small -1 [24m]
70
+ self.exp = 3.0
71
+ self.ffn_shared = 3
72
+ self.hidden_size = 512
73
+ self.num_heads = 8
74
+ self.num_layers = 12
75
+ self.pe_dim = 64
76
+ self.head_dim = 64
77
+ self.groups = 8
78
+ elif lvl == -2:
79
+ # small -2 [12m]
80
+ self.exp = 2.0
81
+ self.ffn_shared = 4
82
+ self.hidden_size = 512
83
+ self.num_heads = 7
84
+ self.num_layers = 8
85
+ self.pe_dim = 48
86
+ self.head_dim = 48
87
+ self.groups = 6
88
+ elif lvl == -3:
89
+ # small -3 [6m]
90
+ self.exp = 2.0
91
+ self.ffn_shared = 3
92
+ self.hidden_size = 384
93
+ self.num_heads = 7
94
+ self.num_layers = 6
95
+ self.pe_dim = 48
96
+ self.head_dim = 32
97
+ self.groups = 6
98
+ ######## large #######
99
+ elif lvl == 1:
100
+ # large +1 [0.2b]
101
+ self.exp = 2.0
102
+ self.ffn_shared = 3
103
+ self.hidden_size = 768
104
+ self.num_heads = 12
105
+ self.num_layers = 24
106
+ self.pe_dim = 96
107
+ self.head_dim = 64
108
+ self.groups = 6
109
+ elif lvl == 2:
110
+ # large +2 [0.6b]
111
+ self.exp = 3.0
112
+ self.ffn_shared = 3
113
+ self.hidden_size = 1344
114
+ self.num_heads = 25
115
+ self.num_layers = 24
116
+ self.pe_dim = 192
117
+ self.head_dim = 96
118
+ self.groups = 7
119
+ else:
120
+ raise ValueError(f"Invalid level: {lvl}")
121
+
122
+ class RMSNorm(torch.nn.Module):
123
+ def __init__(self, dim: int, eps: float = 1e-6):
124
+ super().__init__()
125
+ self.eps = eps
126
+ self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
127
+
128
+ def _norm(self, x):
129
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
130
+
131
+ def forward(self, x):
132
+ output = self._norm(x.float())
133
+ output = output * self.weight.float()
134
+ return output.type_as(x)
135
+
136
+ def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 5e4):
137
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
138
+ t = torch.arange(end, device=freqs.device)
139
+ freqs = torch.outer(t, freqs).float()
140
+ freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
141
+ freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
142
+ return freqs_cos, freqs_sin
143
+
144
+
145
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0):
146
+ def rotate_half(x):
147
+ return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
148
+
149
+ q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
150
+ k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
151
+ return q_embed, k_embed
152
+
153
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
154
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
155
+ b, h, l, ch = x.shape
156
+ if n_rep == 1:
157
+ return x
158
+ return (
159
+ x[:, :, None, :, :]
160
+ .expand(b, h, n_rep, l, ch)
161
+ .reshape(b, h * n_rep, l, ch)
162
+ )
163
+
164
+
165
+ class PEGA(nn.Module):
166
+ """
167
+ 位置编码门控注意力
168
+ """
169
+ def __init__(self, config: YConfig1_1):
170
+ super().__init__()
171
+ self.dropout = config.dropout # dropout rate
172
+ self.hidden_size = config.hidden_size # 输入通道大小
173
+ self.num_heads = config.num_heads # 总注意力头数
174
+ self.pe_dim = config.pe_dim # 位置嵌入维度数
175
+ self.head_dim = config.head_dim # 每个头的维度
176
+ self.groups = config.groups # GQA头数
177
+ self.hidden_kv_dim = int(self.head_dim * self.num_heads // self.groups)
178
+ self.gate_act = ACT2FN[config.hidden_act]
179
+ self.delta_kv_only = False
180
+
181
+ assert self.num_heads % self.groups == 0, "num_heads must be divisible by groups"
182
+
183
+ # self.qpe = nn.Linear(self.hidden_size, self.pe_dim, bias=False)
184
+ # self.kpe = nn.Linear(self.hidden_size, self.pe_dim, bias=False)
185
+ # self.q = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
186
+ # self.kv = nn.Linear(self.hidden_size, self.hidden_kv_dim, bias=False)
187
+ # equals to above
188
+ self.qkv_list = [self.pe_dim, self.pe_dim, self.num_heads * self.head_dim, self.hidden_kv_dim]
189
+ self.qkv = nn.Linear(self.hidden_size, sum(self.qkv_list), bias=False)
190
+ self.o = nn.Linear(self.num_heads * self.hidden_kv_dim, self.hidden_size, bias=False)
191
+ self.gate = nn.Linear(self.hidden_kv_dim, self.num_heads * self.hidden_kv_dim, bias=False)
192
+
193
+ self.rsqrt_dim = 1.0 / math.sqrt(self.head_dim)
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
199
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
200
+ attention_mask: Optional[torch.Tensor] = None,
201
+ use_cache: bool = False,
202
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
203
+
204
+ b, l, _ = x.shape
205
+ cos, sin = position_embeddings # [L, head_dim]
206
+
207
+ # qpe = self.qpe(x) # [b, l, pe]
208
+ # kpe = self.kpe(x) # [b, l, pe]
209
+ # q = self.q(x) # [b, l, nope * hc]
210
+ # kv = self.kv(x) # [b, l, ckv]
211
+ qkv = self.qkv(x)
212
+ qpe, kpe, q, kv = torch.split(qkv, self.qkv_list, dim=-1)
213
+
214
+ # 应用 RoPE
215
+ qpe, kpe = apply_rotary_pos_emb(
216
+ qpe,
217
+ kpe,
218
+ cos[:l],
219
+ sin[:l],
220
+ )
221
+ deltakv = None
222
+ if self.delta_kv_only:
223
+ # 仅返回 delta kv
224
+ deltakv = (kpe, kv)
225
+
226
+ # kv_cache实现
227
+ if past_key_value is not None:
228
+ kpe = torch.cat([past_key_value[0], kpe], dim=1)
229
+ kv = torch.cat([past_key_value[1], kv], dim=1)
230
+ past_kv = (kpe, kv) if use_cache else None
231
+ _, l_all, _ = kv.shape
232
+
233
+ dropout_p = self.dropout if self.training else 0.0
234
+ attn_mask = None
235
+ if attention_mask is not None:
236
+ attn_mask = attention_mask.view(b, 1, 1, -1).expand(b, 1, l, -1)
237
+ attn_mask = attn_mask.bool() if attention_mask is not None else None
238
+
239
+ qpe = qpe.reshape(b, l, 1, self.pe_dim).permute(0, 2, 1, 3) # [b, pe, l, hc]
240
+ kpe = kpe.reshape(b, l_all, 1, self.pe_dim).permute(0, 2, 1, 3) # [b, pe, l_all, hc]
241
+ q = q.reshape(b, l, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [b, nope, l, hc]
242
+ nopek = kv.reshape(b, l_all, self.num_heads // self.groups, self.head_dim).permute(0, 2, 1, 3) # [b, g, l_all, hc]
243
+ kv = kv.reshape(b, l_all, 1, self.hidden_kv_dim).permute(0, 2, 1, 3) # [b, 1, l_all, hc]
244
+
245
+ if self.training:
246
+ peo = nn.functional.scaled_dot_product_attention(
247
+ qpe, kpe, kv,
248
+ attn_mask=attn_mask, dropout_p=dropout_p if self.training else 0.0, is_causal=True
249
+ )
250
+ nopeo = nn.functional.scaled_dot_product_attention(
251
+ q, repeat_kv(nopek, self.groups), repeat_kv(kv, self.num_heads),
252
+ attn_mask=attn_mask, dropout_p=dropout_p if self.training else 0.0, is_causal=True
253
+ )
254
+ else:
255
+ # peo = nn.functional.scaled_dot_product_attention(
256
+ # qpe, kpe, kv,
257
+ # attn_mask=attn_mask, dropout_p=dropout_p if self.training else 0.0, is_causal=l != 1
258
+ # )
259
+ # nopeo = nn.functional.scaled_dot_product_attention(
260
+ # q, repeat_kv(nopek, self.groups), repeat_kv(kv, self.num_heads),
261
+ # attn_mask=attn_mask, dropout_p=dropout_p if self.training else 0.0, is_causal=l != 1
262
+ # )
263
+ peo = self.sdpa_math(qpe, kpe, kv, attn_mask, 0.0)
264
+ nopeo = self.sdpa_math(q, repeat_kv(nopek, self.groups), repeat_kv(kv, self.num_heads), attn_mask, 0.0)
265
+ peo = peo.permute(0, 2, 1, 3).reshape(b, l, -1)
266
+ nopeo = nopeo.permute(0, 2, 1, 3).reshape(b, l, -1)
267
+ gate = self.gate_act(self.gate(peo))
268
+ out = nopeo * gate
269
+ out = self.o(out)
270
+ out = nn.functional.dropout(out, p=self.dropout, training=self.training)
271
+ return out, (deltakv if self.delta_kv_only else past_kv)
272
+
273
+ def sdpa_math(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor, attn_mask: Optional[torch.Tensor] = None,
274
+ dropout_p: float = 0.0) -> (torch.Tensor, torch.Tensor):
275
+ b, h, l, c = q.shape
276
+ scores = (q @ k.transpose(-2, -1)) * self.rsqrt_dim
277
+ casual_mask = torch.triu(
278
+ torch.full((l, l), float("-inf"), device=scores.device),
279
+ diagonal=1
280
+ ).unsqueeze(0).unsqueeze(0)# [1, 1, l, l]
281
+ # 在左侧 zero pad 到 scores 的形状 [1, 1, l, l_all]
282
+ casual_mask = nn.functional.pad(casual_mask, (scores.shape[-1] - l, 0), "constant", 0.0)# [1, 1, l, l_all]
283
+ scores += casual_mask
284
+
285
+ if attn_mask is not None:
286
+ attn_mask = (1.0 - attn_mask.type_as(scores)) * -1e9
287
+ scores = scores + attn_mask
288
+
289
+ scores = nn.functional.softmax(scores.float(), dim=-1).type_as(q)
290
+ scores = nn.functional.dropout(scores, p=dropout_p, training=self.training)# [b, h, l, l]
291
+ output = scores @ v
292
+ return output
293
+
294
+ def use_delta_kv_only(self, enable:bool=True):
295
+ # 仅返回 delta kv,减少内存开销
296
+ self.delta_kv_only = enable
297
+
298
+ class YFFN(nn.Module):
299
+ """
300
+ shared up & down GeGLU, LoE (Lack of Expert) arc
301
+ """
302
+ def __init__(self, config: YConfig1_1):
303
+ super().__init__()
304
+ self.act = ACT2FN[config.hidden_act]
305
+ self.channels = config.hidden_size
306
+ self.exp = config.exp
307
+ self.c_up = int(self.channels * self.exp)
308
+ self.ffn_shared = config.ffn_shared
309
+
310
+ self.up = nn.Linear(self.channels, self.c_up, bias=False)
311
+ self.down = nn.Linear(self.c_up, self.channels, bias=False)
312
+ self.gates = nn.ModuleList([
313
+ nn.Linear(self.channels, self.c_up, bias=False) for _ in range(self.ffn_shared)
314
+ ])
315
+
316
+ def forward(self, x:torch.Tensor, index:int, up_res:torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
317
+ up = self.up(x)
318
+ if up_res is not None:
319
+ up += up_res
320
+ gate = self.gates[index](x)
321
+ gate = self.act(gate)
322
+ up *= gate
323
+ x = self.down(up)
324
+ return x, up
325
+
326
+ class YBlock(nn.Module):
327
+ """
328
+ Groups of Transformer layers with shared FFN
329
+ num layers is ffn_shared
330
+ """
331
+ def __init__(self, config: YConfig1_1):
332
+ super().__init__()
333
+ self.attentions = nn.ModuleList([PEGA(config) for _ in range(config.ffn_shared)])
334
+ self.ffn = YFFN(config)
335
+ self.attn_norms = nn.ModuleList([
336
+ RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.ffn_shared)
337
+ ])
338
+ self.ffn_norms = nn.ModuleList([
339
+ RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.ffn_shared)
340
+ ])
341
+ self.use_self_distill = config.self_distill
342
+
343
+ def forward(self,
344
+ x: torch.Tensor,
345
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
346
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,# ffn_shard * kv cache
347
+ use_cache: bool = False,
348
+ attention_mask: Optional[torch.Tensor] = None
349
+ ):
350
+ b, l, _ = x.shape
351
+ kv_outs = []
352
+ ups = None
353
+ cos_loss = None
354
+ for i, (layer, kv_cache) in enumerate(zip(self.attentions, past_key_values)):
355
+ x0 = x
356
+ res = x
357
+ x = self.attn_norms[i](x)
358
+ x, kv_out = layer(
359
+ x = x,
360
+ position_embeddings=position_embeddings,
361
+ past_key_value=kv_cache,
362
+ attention_mask=attention_mask,
363
+ use_cache=use_cache
364
+ )
365
+ x += res
366
+ res = x
367
+ x = self.ffn_norms[i](x)
368
+ x, ups = self.ffn(x, i, ups)
369
+ x += res
370
+ kv_outs.append(kv_out)
371
+ if self.training and self.use_self_distill:
372
+ xd = x.detach()
373
+ # cosine loss
374
+ c_loss = 1.0 - nn.functional.cosine_similarity(x0, xd, dim=-1).mean()
375
+ cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss
376
+ return x, kv_outs, cos_loss
377
+
378
+ def delta_kv_only(self, delta_kv:bool=True):
379
+ for i in range(len(self.attentions)):
380
+ self.attentions[i].use_delta_kv_only(delta_kv)
381
+
382
+
383
+ class YModel(nn.Module):
384
+ def __init__(self, config: YConfig1_1):
385
+ super().__init__()
386
+ self.vocab_size = config.vocab_size
387
+ self.num_layers = config.num_layers
388
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
389
+ self.dropout = config.dropout
390
+ self.ffn_shared = config.ffn_shared
391
+
392
+ assert self.num_layers % self.ffn_shared == 0, "num_layers must be divisible by ffn_shared"
393
+ self.blks = nn.ModuleList([
394
+ YBlock(config) for _ in range(self.num_layers // self.ffn_shared)
395
+ ])
396
+
397
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
398
+
399
+ freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.pe_dim,
400
+ end=config.max_position_embeddings, theta=config.rope_theta)
401
+ self.register_buffer("freqs_cos", freqs_cos, persistent=False)
402
+ self.register_buffer("freqs_sin", freqs_sin, persistent=False)
403
+
404
+ def forward(self,
405
+ input_ids: Optional[torch.Tensor] = None,
406
+ attention_mask: Optional[torch.Tensor] = None,
407
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
408
+ use_cache: bool = False,
409
+ **kwargs
410
+ ):
411
+ batch_size, seq_length = input_ids.shape
412
+ past_key_values = past_key_values or [None] * self.num_layers
413
+ start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
414
+
415
+ x = self.embed_tokens(input_ids)
416
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
417
+
418
+ position_embeddings = (
419
+ self.freqs_cos[start_pos:start_pos + seq_length],
420
+ self.freqs_sin[start_pos:start_pos + seq_length]
421
+ )
422
+
423
+ presents = []
424
+ cos_loss = None
425
+ for layer_idx, block in enumerate(self.blks):
426
+ past_key_value = past_key_values[self.ffn_shared * layer_idx: self.ffn_shared * (layer_idx + 1)]
427
+ x, present, c_loss = block(
428
+ x = x,
429
+ position_embeddings = position_embeddings,
430
+ past_key_values=past_key_value,
431
+ use_cache=use_cache,
432
+ attention_mask=attention_mask
433
+ )
434
+ presents.extend(present)
435
+ cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss
436
+
437
+ x = self.norm(x)
438
+ return x, presents, (cos_loss / self.num_layers if cos_loss is not None else None)
439
+
440
+ def delta_kv_only(self, delta_kv:bool=True):
441
+ for i in range(len(self.blks)):
442
+ self.blks[i].delta_kv_only(delta_kv)
443
+
444
+ class YForCausalLM1_1(PreTrainedModel, GenerationMixin):
445
+ config_class = YConfig1_1
446
+
447
+ def __init__(self, config: YConfig1_1 = None):
448
+ self.config = config or YConfig1_1()
449
+ super().__init__(self.config)
450
+ self.model = YModel(self.config)
451
+ self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
452
+ self.model.embed_tokens.weight = self.lm_head.weight
453
+ self.OUT = CausalLMOutputWithPast()
454
+
455
+ def forward(self,
456
+ input_ids: Optional[torch.Tensor] = None,
457
+ attention_mask: Optional[torch.Tensor] = None,
458
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
459
+ use_cache: bool = False,
460
+ logits_to_keep: Union[int, torch.Tensor] = 0,
461
+ **args):
462
+ h, past_kvs, cos_loss = self.model(
463
+ input_ids=input_ids,
464
+ attention_mask=attention_mask,
465
+ past_key_values=past_key_values,
466
+ use_cache=use_cache,
467
+ **args
468
+ )
469
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
470
+ logits = self.lm_head(h[:, slice_indices, :])
471
+ self.OUT.__setitem__('last_hidden_state', h)
472
+ self.OUT.__setitem__('logits', logits)
473
+ self.OUT.__setitem__('aux_loss', 0.0)
474
+ self.OUT.__setitem__('past_key_values', past_kvs)
475
+ self.OUT.__setitem__('dist_loss', cos_loss)
476
+ return self.OUT
477
+
478
+ def delta_kv_only(self, delta_kv:bool=True):
479
+ self.model.delta_kv_only(delta_kv)