appleeji committed on
Commit
ecef0fd
·
verified ·
1 Parent(s): dfd3147

Upload cnets.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. cnets.py +890 -0
cnets.py ADDED
@@ -0,0 +1,890 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+ from collections import Counter
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ import os
29
+ from transformers.integrations.deepspeed import HfDeepSpeedConfig
30
+ from transformers.activations import ACT2FN
31
+ from transformers import AutoTokenizer
32
+ from modeling_llama_kv import LlamaForCausalLM
33
+ from modeling_qwen_kv import Qwen3ForCausalLM
34
+ from configs import EConfig
35
+ from safetensors import safe_open
36
+ from datasets import load_dataset
37
+ import multiprocessing
38
+
39
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
40
+ def _make_causal_mask(
41
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
42
+ ):
43
+ """
44
+ Make causal mask used for bi-directional self-attention.
45
+ """
46
+ bsz, tgt_len = input_ids_shape
47
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
48
+ mask_cond = torch.arange(mask.size(-1), device=device)
49
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
50
+ mask = mask.to(dtype)
51
+
52
+ if past_key_values_length > 0:
53
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
54
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
55
+
56
+
57
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
58
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
59
+ """
60
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
61
+ """
62
+ bsz, src_len = mask.size()
63
+ tgt_len = tgt_len if tgt_len is not None else src_len
64
+
65
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
66
+
67
+ inverted_mask = 1.0 - expanded_mask
68
+
69
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
70
+
71
+
72
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    When `n_rep == 1` the input tensor is returned unchanged (same object).
    """
    if n_rep == 1:
        return hidden_states
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    # Insert a repeat axis after the KV-head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
82
+
83
+
84
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension into a first half `x1` and second half `x2`
    and returns the concatenation `(-x2, x1)`.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
89
+
90
+
91
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q, k: query / key tensors; the comments below assume layout
            `[bs, heads, seq_len, dim]` so that the gathered cos/sin broadcast
            over the head axis.
        cos, sin: cached tables whose first two dimensions are size 1
            (`[1, 1, max_seq_len, dim]`).
        position_ids: integer indices used to gather rows from `cos` / `sin`.

    Returns:
        `(q_embed, k_embed)` with the rotation applied.
    """
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
100
+
101
+
102
class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with a lazily grown cos/sin cache.

    The cache is built for `max_position_embeddings` positions at construction
    and rebuilt whenever `forward` is asked for a longer sequence.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build `cos_cached` / `sin_cached` for `seq_len` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return `(cos, sin)` tables of shape `[1, 1, seq_len, dim]` in `x.dtype`.

        Args:
            x: tensor laid out as `[bs, num_attention_heads, seq_len, head_size]`;
                only its dtype/device (and, if `seq_len` is omitted, its
                second-to-last dimension) are used.
            seq_len: number of positions required. Defaults to `x.shape[-2]`;
                previously the `None` default raised a `TypeError`.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
136
+
137
+
138
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__(), which calls _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build the cos/sin cache with position indices divided by `scaling_factor`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
155
+
156
+
157
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__(), which calls _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build the cos/sin cache, rescaling the RoPE base when `seq_len` exceeds the trained length."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Enlarge the base as a function of how far seq_len exceeds the
            # configured maximum, then recompute (and replace) inv_freq.
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
181
+
182
+
183
+
184
class LlamaAttention(nn.Module):
    """Multi-headed attention over a doubled-width input with a per-step key/value cache.

    The q/k/v projections take inputs of width `2 * hidden_size`. Besides the
    usual attention over the current step's keys, each previously cached step
    contributes one extra, position-wise (diagonal) attention term.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size * 2, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Instantiate the rotary embedding variant selected by `config.rope_scaling`."""
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_hidden: Optional[List[List[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Run attention for one draft step.

        Args:
            hidden_states: `(bsz, q_len, 2 * hidden_size)` input.
            cache_hidden: `[keys, values]`, each a list holding one tensor per
                earlier step. `None` or an empty list means no earlier steps.
                The caller's lists are never mutated.
            attention_mask: additive mask for the scores against the first
                cached step's keys; skipped when `None`.
            position_ids: position indices for RoPE; required. They are
                shifted by the number of cached steps.
            past_key_value, output_attentions, use_cache: accepted for
                interface compatibility; not used by this implementation.

        Returns:
            `(attn_output, [keys, values])` where the second item is the cache
            extended with this step's key/value states.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Avoid modify hidden cache inplace which will cause in-place modification error when enable gradient checkpoint.
        # Return the updated hidden cache instead.
        # The None check must precede any indexing of cache_hidden (it is the default value).
        if cache_hidden is None or len(cache_hidden) == 0:
            local_cache_k = []
            local_cache_v = []
        else:
            local_cache_k = list(cache_hidden[0])
            local_cache_v = list(cache_hidden[1])
        num_cached_steps = len(local_cache_k)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(query_states, seq_len=q_len + num_cached_steps)
        cos, sin = cos.to(query_states.device), sin.to(query_states.device)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids + num_cached_steps
        )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        local_cache_k.append(key_states)
        local_cache_v.append(value_states)
        num_steps = len(local_cache_k)

        scale = math.sqrt(self.head_dim)

        # Full (q_len x q_len) scores against the first step's keys.
        attn_weights = torch.matmul(query_states, local_cache_k[0].transpose(2, 3)) / scale
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Every later step contributes one extra score column: the dot product
        # of each query with the key at the same sequence position.
        for step in range(1, num_steps):
            step_scores = (query_states * local_cache_k[step]).sum(-1) / scale
            attn_weights = torch.cat((attn_weights, step_scores[..., None]), dim=-1)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        attn_output = torch.matmul(attn_weights[..., :q_len], local_cache_v[0])
        for step in range(1, num_steps):
            step_weight = attn_weights[..., q_len + step - 1]
            attn_output = attn_output + step_weight[..., None] * local_cache_v[step]

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # Return the updated hidden cache.
        new_past_key_value = [local_cache_k, local_cache_v]
        return attn_output, new_past_key_value
314
+
315
+
316
class LlamaMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config, last=True):
        super().__init__()
        # `last` is stored but does not change the layer shapes in this version.
        self.last = last
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the MLP to `x` (last dimension `hidden_size`).

        When `config.pretraining_tp > 1` the projections are evaluated as
        `pretraining_tp` weight slices and recombined, which is mathematically
        the same computation as the single-matmul path.
        """
        tp = self.config.pretraining_tp
        if tp > 1:
            slice_size = self.intermediate_size // tp
            gate_proj_slices = self.gate_proj.weight.split(slice_size, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice_size, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice_size, dim=1)

            gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(tp)], dim=-1)
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice_size, dim=2)
            down_proj = sum(F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(tp))
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj
352
+
353
+
354
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: size of the normalised (last) dimension.
            eps: constant added to the mean square for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Scale `hidden_states` by the reciprocal RMS of its last dimension, then by `weight`."""
        input_dtype = hidden_states.dtype
        # Compute the statistics in float32, then cast back before applying the weight.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
369
+
370
+
371
class LlamaDecoderLayeremb(nn.Module):
    """Decoder layer whose attention input is the token embedding concatenated with the hidden state."""

    def __init__(self, config, last=True):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config, last=last)
        self.last = last
        self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_emb: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_hidden: Optional[List[List[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ):
        """
        Args:
            input_emb (`torch.FloatTensor`): token embeddings; normalised by `input_layernorm`.
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            cache_hidden (`List[List[torch.Tensor]]`, *optional*): `[keys, values]` cache forwarded to the
                attention module. Defaults to `None` (previously a shared mutable `[]` default).
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the attention module.
            past_key_value, output_attentions, use_cache: forwarded unchanged to the attention module.

        Returns:
            `((hidden_states, return_hidden), latest_hidden_cache)` where `return_hidden` is the
            concatenated, normalised attention input and `latest_hidden_cache` is the cache returned
            by the attention module.
        """
        residual = hidden_states

        hidden_states = self.hidden_norm(hidden_states)
        input_emb = self.input_layernorm(input_emb)

        # Attention consumes [embedding ; hidden] along the feature axis (width 2 * hidden_size).
        hidden_states = torch.cat((input_emb, hidden_states), dim=-1)
        return_hidden = hidden_states

        # Self Attention
        hidden_states, latest_hidden_cache = self.self_attn(
            cache_hidden=cache_hidden,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Feed-forward block with its own residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, return_hidden)

        return outputs, latest_hidden_cache
444
+
445
+
446
@torch.no_grad()
def padding(tensor, left=True):
    """Shift `tensor` by one position along dim 1, filling the vacated slot with zeros.

    With `left=True` a zero slice is inserted at the front and the last element
    is dropped; with `left=False` the first element is dropped and a zero slice
    is appended. The shape is unchanged.
    """
    zeropadding = torch.zeros_like(tensor[:, -1:])
    if left:
        return torch.cat((zeropadding, tensor[:, :-1]), dim=1)
    return torch.cat((tensor[:, 1:], zeropadding), dim=1)
454
+
455
+
456
def process_data(data_chunk):
    """Count token occurrences at positions where the loss mask equals 1.

    Args:
        data_chunk: mapping with keys `"input_ids"` and `"loss_mask"`; each is a
            list of samples, and each sample wraps its sequence in an outer
            length-1 dimension (only element `[0]` is read).

    Returns:
        `Counter` mapping token id -> number of masked-in occurrences.
    """
    token_dict = Counter()
    for sample_ids, sample_mask in zip(data_chunk["input_ids"], data_chunk["loss_mask"]):
        ids = sample_ids[0]
        mask = sample_mask[0]
        for token, keep in zip(ids, mask):
            if keep == 1:
                token_dict[token] += 1

    return token_dict
469
+
470
+
471
def merge_dicts(dicts):
    """Merge multiple Counter dictionaries by summing their counts."""
    result = Counter()
    for d in dicts:
        result.update(d)
    return result
477
+
478
+
479
+ class Model(nn.Module):
480
def __init__(self, config, ds_config, training_config, load_head=False, load_emb=True, path=None, model_type='llama'):
    """Build the draft model around a frozen target model.

    Args:
        config: draft-model config; must expose `hidden_size`, `vocab_size`,
            `draft_vocab_size`, `pad_token_id`, `rms_norm_eps` and the fields
            read by `LlamaDecoderLayeremb`.
        ds_config: DeepSpeed config dict or `None`. When its ZeRO stage is 3 an
            `HfDeepSpeedConfig` is created before the target model is loaded.
        training_config: dict; `"gradient_checkpointing"` is read here.
        load_head: accepted for interface compatibility; not used here.
        load_emb: if true, initialise `embed_tokens` from the checkpoint under
            `path`; otherwise create a freshly initialised embedding.
        path: directory of the target model checkpoint.
        model_type: `'qwen3'` loads `Qwen3ForCausalLM`; anything else loads
            `LlamaForCausalLM`.
    """
    super().__init__()
    self.model_type = model_type
    self.train_config = training_config
    # Setting dschf to allow efficient ZeRO-3 usage between hf and ds.
    # The object only needs to exist while from_pretrained runs below.
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None
    self.midlayer = LlamaDecoderLayeremb(config)
    self.gradient_checkpointing = self.train_config["gradient_checkpointing"]
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size
    self.hidden_size = config.hidden_size
    self.draft_vocab_size = config.draft_vocab_size
    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.length = 6  # Modified by ablation script

    # Load target model based on model_type
    if self.model_type == 'qwen3':
        self.target_model = Qwen3ForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
    else:  # default to llama
        self.target_model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.float16)

    # The target model is a frozen feature extractor.
    self.target_model.eval()
    self.fc = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=False)
    for param in self.target_model.parameters():
        param.requires_grad = False

    if not load_emb:
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    else:
        import json

        emb_key = "model.embed_tokens.weight"
        try:
            # Preferred: sharded safetensors checkpoint.
            with open(os.path.join(path, "model.safetensors.index.json"), "r") as f:
                index_json = json.loads(f.read())
                emb_path = index_json["weight_map"][emb_key]
            with safe_open(os.path.join(path, emb_path),
                           framework="pt",
                           device="cpu") as f:
                tensor_slice = f.get_slice(emb_key)
                vocab_size, hidden_dim = tensor_slice.get_shape()
                tensor = tensor_slice[:, :hidden_dim].float()
        except Exception:
            # Fallback: sharded PyTorch .bin checkpoint. `Exception` (not a bare
            # except) so KeyboardInterrupt/SystemExit are not swallowed.
            with open(os.path.join(path, "pytorch_model.bin.index.json"), "r") as f:
                index_json = json.loads(f.read())
                emb_path = index_json["weight_map"][emb_key]
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            # map_location="cpu" mirrors the safetensors branch above.
            weights = torch.load(os.path.join(path, emb_path), map_location="cpu")
            tensor = weights[emb_key].float()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, _weight=tensor)

    self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)

    # Embeddings are never trained, whether loaded or freshly created.
    for param in self.embed_tokens.parameters():
        param.requires_grad = False
541
+
542
def scandata(self, datapath, tokenizerpath, num_proc=48):
    """Build (or load from cache) the draft-vocabulary mapping buffers.

    Scans the JSON dataset at `datapath`, counts how often each token occurs at
    positions kept by the loss mask, keeps the `draft_vocab_size` most frequent
    tokens and registers two buffers:

    * `d2t`: for draft index `i`, `d2t[i] + i` is the target-vocabulary id.
    * `t2d`: boolean vector over the target vocabulary, true for kept tokens.

    The result is cached in `cache.pt` (llama) or `cache_<model_type>.pt` in
    the current working directory and reused on later calls.

    Args:
        datapath: path(s) passed to `load_dataset('json', data_files=...)`.
        tokenizerpath: path/name passed to `AutoTokenizer.from_pretrained`.
        num_proc: number of worker processes for preprocessing and counting.

    Raises:
        ValueError: if no sample survives preprocessing.
    """
    N = self.draft_vocab_size

    # Use different cache files for different model types
    cache_file = f"cache_{self.model_type}.pt" if self.model_type != 'llama' else "cache.pt"

    if not os.path.exists(cache_file):
        tokenizer = AutoTokenizer.from_pretrained(tokenizerpath)
        dataset = load_dataset('json', data_files=datapath)
        dataset = dataset['train']
        original_columns1 = dataset.column_names

        # Set separators based on model type
        if self.model_type == 'qwen3':
            sep = "<|im_end|>\n<|im_start|>assistant\n"
            sep2 = "<|im_end|>\n<|im_start|>user\n"
        else:  # llama
            sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>"

        def preprocess_function(examples):
            new_examples = {
                "input_ids": [],
                "loss_mask": []
            }
            for i in range(len(examples['id'])):
                messages = [
                    {"role": "system",
                     "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
                ]
                convroles = ["user", "assistant"]
                roles = {"human": "user", "gpt": "assistant"}
                source = examples['conversations'][i]
                if not source:
                    continue
                if roles[source[0]["from"]] != "user":
                    # Skip the first one if it is not from human
                    source = source[1:]
                for j, sentence in enumerate(source):
                    role = roles[sentence["from"]]
                    assert role == convroles[j % 2], f"{i}"
                    messages.append(
                        {"role": role, "content": sentence["value"]}
                    )
                conversation = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )

                if not tokenizer.pad_token_id:
                    tokenizer.pad_token_id = tokenizer.unk_token_id

                input_ids = tokenizer(
                    conversation,
                    return_tensors="pt",
                    add_special_tokens=False,
                ).input_ids[0]
                # When construct draft model vocab,
                # filter out samples which is longer than max_len,
                # instead of truncating them.
                if len(input_ids) > self.train_config["max_len"]:
                    continue
                loss_mask = torch.ones_like(input_ids)

                turns = conversation.split(sep2)

                # Skip samples with invalid conversation structure
                if len(turns) < 2:
                    continue

                # Re-attach the system prefix to the first user turn.
                turns[1] = turns[0] + sep2 + turns[1]
                turns = turns[1:]

                cur_len = 1
                loss_mask[:cur_len] = 0
                # NOTE(review): the offsets below (-1, -2, +1, 3) are
                # tokenizer-specific alignment constants inherited from the
                # Llama chat template; confirm they also hold for other templates.
                for turn_idx, turn in enumerate(turns):
                    if turn == "":
                        break
                    turn_len = len(tokenizer(turn).input_ids)

                    parts = turn.split(sep)
                    if len(parts) != 2:
                        break
                    parts[0] += sep
                    instruction_len = len(tokenizer(parts[0]).input_ids) - 1

                    # Ignore the user instructions
                    if turn_idx == 0:
                        loss_mask[cur_len: cur_len + instruction_len - 2] = 0
                    else:
                        loss_mask[cur_len - 3: cur_len + instruction_len + 1] = 0
                    cur_len += turn_len
                    if turn_idx != 0:
                        cur_len += 3

                loss_mask[cur_len:] = 0

                new_examples["input_ids"].append(input_ids[None, :])
                new_examples["loss_mask"].append(loss_mask[None, :])

            return new_examples

        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=num_proc,
            remove_columns=original_columns1,
            load_from_cache_file=False
        )

        if len(dataset) == 0:
            raise ValueError("scandata: no samples left after preprocessing; cannot build draft vocabulary")

        num_processes = num_proc
        # Ceil-divide so that at most `num_processes` chunks are produced.
        chunk_size = len(dataset) // num_processes + (len(dataset) % num_processes > 0)
        chunks = [dataset[i:i + chunk_size] for i in range(0, len(dataset), chunk_size)]

        # Create the process pool and count tokens per chunk in parallel.
        with multiprocessing.Pool(num_processes) as pool:
            results = pool.map(process_data, chunks)

        # Merge the per-chunk counters.
        token_dict = merge_dicts(results)

        total_frequency = sum(token_dict.values())
        top_N = token_dict.most_common(N)
        top_N_frequency_sum = sum(freq for key, freq in top_N)
        top_N_ratio = top_N_frequency_sum / total_frequency
        print(f"top {N} token frequency ratio: {top_N_ratio:.2%}")
        used_tokens = sorted(key for key, freq in top_N)
        # Set membership keeps the vocabulary scan linear instead of O(vocab * N).
        used_token_set = set(used_tokens)
        d2t = [token - idx for idx, token in enumerate(used_tokens)]
        t2d = [i in used_token_set for i in range(self.vocab_size)]
        d2t = torch.tensor(d2t)
        t2d = torch.tensor(t2d)
        cache = {
            "d2t": d2t,
            "t2d": t2d
        }
        torch.save(cache, cache_file)
    else:
        cache = torch.load(cache_file)
        d2t = cache["d2t"]
        t2d = cache["t2d"]
    self.register_buffer("d2t", d2t)
    self.register_buffer("t2d", t2d)
    self.l1smooth = nn.SmoothL1Loss(reduction="none")
707
+
708
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Combine the causal mask with an optional padding mask.

    Args:
        attention_mask: `[bsz, seq_len]` padding mask or `None`.
        input_shape: `(bsz, tgt_len)` of the current input.
        inputs_embeds: tensor supplying the dtype and device of the result.
        past_key_values_length: number of cached positions for the causal mask.

    Returns:
        Additive mask of shape `[bsz, 1, tgt_seq_len, src_seq_len]`, or `None`
        when the input has a single position and no `attention_mask` is given.
    """
    # create causal mask
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is None:
        return combined_attention_mask

    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
        inputs_embeds.device
    )
    if combined_attention_mask is None:
        return expanded_attn_mask
    return expanded_attn_mask + combined_attention_mask
730
+
731
@torch.no_grad()
def dataprepare(self, input_ids, attention_mask, loss_mask):
    """Run the frozen target model and assemble draft-model training inputs.

    Args:
        input_ids: token ids fed to the target model.
        attention_mask: attention mask fed to the target model.
        loss_mask: per-token mask; a trailing singleton dimension is added.

    Returns:
        `(hidden_states, target, loss_mask, input_ids)` where `hidden_states`
        concatenates the first three entries of the target model's
        `hidden_states` output along the feature axis, and `target` (the
        logits) and `input_ids` are shifted one position towards the start
        with a zero slot appended at the end.
    """
    device = input_ids.device
    outs = self.target_model(input_ids=input_ids, attention_mask=attention_mask)
    # NOTE(review): presumably the custom *_kv target model exposes three
    # selected layers here rather than the first three layers; verify
    # against modeling_llama_kv / modeling_qwen_kv.
    hidden_states = torch.cat(
        (outs.hidden_states[0], outs.hidden_states[1], outs.hidden_states[2]), dim=-1
    )
    target = padding(outs.logits, left=False)
    input_ids = padding(input_ids, left=False)

    # `target` cannot be None at this point: padding() above would already have failed.
    target = target.to(device)
    loss_mask = loss_mask[..., None]
    loss_mask = loss_mask.to(device)

    return hidden_states, target, loss_mask, input_ids
750
+
751
def forward(
    self,
    input_ids,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    loss_mask: Optional[torch.Tensor] = None,
):
    """Run ``self.length`` training steps and collect a loss and accuracy per step.

    Each step embeds the current ``input_ids``, runs ``self.midlayer`` on the
    embeddings together with the previous step's hidden states, and scores the
    resulting logits against the target model's distribution restricted to the
    tokens flagged in the ``t2d`` buffer.  Between steps ``input_ids``,
    ``target`` and ``loss_mask`` are passed through ``padding(..., left=False)``.

    Args:
        input_ids: Token ids; handed to ``self.dataprepare``.
        attention_mask: Optional mask; an all-ones boolean mask is used when
            it is ``None``.
        position_ids: Optional positions; a contiguous range starting at the
            past length is used when it is ``None``.
        past_key_values: Only inspected for the past sequence length.
        use_cache: Accepted for interface compatibility; not read after the
            gradient-checkpointing adjustment.
        output_attentions: Forwarded to ``self.midlayer``.
        output_hidden_states: Accepted for interface compatibility; unused.
        loss_mask: Mask handed to ``self.dataprepare``.

    Returns:
        Tuple ``(plosses, vlosses, acces)``: the list of per-step losses, a
        list that is never populated here, and the list of per-step
        accuracies (Python floats).
    """
    hidden_states, target, loss_mask, input_ids = self.dataprepare(input_ids, attention_mask, loss_mask)

    batch_size, seq_length, _ = hidden_states.shape
    use_checkpoint = bool(self.gradient_checkpointing and self.training)

    # Checkpointing needs at least one input that requires grad.
    if use_checkpoint and not hidden_states.requires_grad:
        hidden_states.requires_grad = True

    hidden_states = self.fc(hidden_states)

    past_len = 0 if past_key_values is None else past_key_values[0][0].shape[2]
    total_len = seq_length + past_len

    if position_ids is None:
        position_ids = torch.arange(
            past_len, seq_length + past_len, dtype=torch.long, device=hidden_states.device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, total_len), dtype=torch.bool, device=hidden_states.device
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), hidden_states, past_len
    )

    if use_checkpoint and use_cache:
        use_cache = False

    def checkpointed_midlayer(*inputs):
        # None for past_key_value
        return self.midlayer(*inputs, None, output_attentions)

    plosses = []
    vlosses = []  # returned as-is; nothing in this method appends to it
    acces = []
    cache_hidden = [[], []]

    num_steps = self.length
    for step in range(num_steps):
        is_last = step == num_steps - 1

        inputs_embeds = self.embed_tokens(input_ids)
        if use_checkpoint and not inputs_embeds.requires_grad:
            inputs_embeds.requires_grad = True
        inputs_embeds = inputs_embeds.to(hidden_states.dtype)

        if use_checkpoint:
            layer_outputs, cache_hidden = torch.utils.checkpoint.checkpoint(
                checkpointed_midlayer,
                inputs_embeds,
                hidden_states,
                cache_hidden,
                attention_mask,
                position_ids,
            )
        else:
            layer_outputs, cache_hidden = self.midlayer(
                input_emb=inputs_embeds,
                hidden_states=hidden_states,
                cache_hidden=cache_hidden,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=output_attentions,
                use_cache=True,
            )

        hidden_states_out = layer_outputs[0]

        with torch.no_grad():
            top_target_token = target.argmax(-1)
            # Move t2d to the same device as the target's argmax tokens.
            self.t2d = self.t2d.to(top_target_token.device)
            # 1 where the target's top token is one of the tokens flagged in t2d.
            kept_token_mask = self.t2d[top_target_token][..., None].int()
            position_mask = kept_token_mask * loss_mask
            # Target distribution over the flagged tokens only.
            restricted_target = target[..., self.t2d].float()
            target_p = nn.functional.softmax(restricted_target, dim=2).detach()

        # The un-normalised output feeds the next step.
        hidden_states = hidden_states_out

        logits = self.lm_head(self.norm(hidden_states_out)).float()
        out_logp = nn.functional.log_softmax(logits, dim=2)
        # Masked cross-entropy against the target distribution.
        loss = -torch.sum(position_mask * (target_p * out_logp), 2).mean()
        plosses.append(loss)

        with torch.no_grad():
            hits = (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)
            acces.append(hits.sum().item() / (loss_mask.sum().item() + 1e-6))

        if not is_last:
            input_ids = padding(input_ids, left=False)
            target = padding(target, left=False)
            loss_mask = padding(loss_mask, left=False)

    return plosses, vlosses, acces
887
+
888
+
889
+
890
+