yaro1214 commited on
Commit
c78bdd3
·
verified ·
1 Parent(s): 5528834

Upload folder using huggingface_hub

Browse files
dflash-hidden5-target5-block32/epoch_6_step_53334/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DFlashDraftModel"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoModel": "dflash.DFlashDraftModel"
9
+ },
10
+ "block_size": 16,
11
+ "bos_token_id": 151643,
12
+ "dflash_config": {
13
+ "mask_token_id": 151669,
14
+ "target_layer_ids": [
15
+ 1,
16
+ 9,
17
+ 17,
18
+ 25,
19
+ 33
20
+ ]
21
+ },
22
+ "dtype": "bfloat16",
23
+ "eos_token_id": 151645,
24
+ "head_dim": 128,
25
+ "hidden_act": "silu",
26
+ "hidden_size": 4096,
27
+ "initializer_range": 0.02,
28
+ "intermediate_size": 12288,
29
+ "layer_types": [
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention"
35
+ ],
36
+ "max_position_embeddings": 40960,
37
+ "max_window_layers": 5,
38
+ "model_type": "qwen3",
39
+ "num_attention_heads": 32,
40
+ "num_hidden_layers": 5,
41
+ "num_key_value_heads": 8,
42
+ "num_target_layers": 36,
43
+ "rms_norm_eps": 1e-06,
44
+ "rope_scaling": null,
45
+ "rope_theta": 1000000,
46
+ "sliding_window": null,
47
+ "tie_word_embeddings": false,
48
+ "transformers_version": "4.57.1",
49
+ "use_cache": true,
50
+ "use_sliding_window": false,
51
+ "vocab_size": 151936
52
+ }
dflash-hidden5-target5-block32/epoch_6_step_53334/dflash.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers import DynamicCache
6
+ from transformers.cache_utils import Cache
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from transformers.models.qwen3.modeling_qwen3 import (
9
+ ALL_ATTENTION_FUNCTIONS,
10
+ FlashAttentionKwargs,
11
+ GradientCheckpointingLayer,
12
+ Qwen3Config,
13
+ Qwen3MLP,
14
+ Qwen3PreTrainedModel,
15
+ Qwen3RMSNorm,
16
+ Qwen3RotaryEmbedding,
17
+ eager_attention_forward,
18
+ rotate_half,
19
+ )
20
+ from typing_extensions import Tuple, Unpack
21
+
22
+
23
+ def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
24
+ if temperature < 1e-5:
25
+ return torch.argmax(logits, dim=-1)
26
+ bsz, seq_len, vocab_size = logits.shape
27
+ logits = logits.view(-1, vocab_size)
28
+ logits = logits / temperature
29
+ probs = torch.softmax(logits, dim=-1)
30
+ return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)
31
+
32
+
33
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
34
+ cos = cos.unsqueeze(unsqueeze_dim)
35
+ sin = sin.unsqueeze(unsqueeze_dim)
36
+ q_len = q.size(-2)
37
+ q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :])
38
+ k_embed = (k * cos) + (rotate_half(k) * sin)
39
+ return q_embed, k_embed
40
+
41
+
42
+ class Qwen3DFlashAttention(nn.Module):
43
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
44
+
45
+ def __init__(self, config: Qwen3Config, layer_idx: int):
46
+ super().__init__()
47
+ self.config = config
48
+ self.layer_idx = layer_idx
49
+ self.head_dim = getattr(
50
+ config, "head_dim", config.hidden_size // config.num_attention_heads
51
+ )
52
+ self.num_key_value_groups = (
53
+ config.num_attention_heads // config.num_key_value_heads
54
+ )
55
+ self.scaling = self.head_dim**-0.5
56
+ self.attention_dropout = config.attention_dropout
57
+ self.is_causal = False
58
+ self.q_proj = nn.Linear(
59
+ config.hidden_size,
60
+ config.num_attention_heads * self.head_dim,
61
+ bias=config.attention_bias,
62
+ )
63
+ self.k_proj = nn.Linear(
64
+ config.hidden_size,
65
+ config.num_key_value_heads * self.head_dim,
66
+ bias=config.attention_bias,
67
+ )
68
+ self.v_proj = nn.Linear(
69
+ config.hidden_size,
70
+ config.num_key_value_heads * self.head_dim,
71
+ bias=config.attention_bias,
72
+ )
73
+ self.o_proj = nn.Linear(
74
+ config.num_attention_heads * self.head_dim,
75
+ config.hidden_size,
76
+ bias=config.attention_bias,
77
+ )
78
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
79
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
80
+ self.sliding_window = (
81
+ config.sliding_window
82
+ if config.layer_types[layer_idx] == "sliding_attention"
83
+ else None
84
+ )
85
+
86
+ def forward(
87
+ self,
88
+ hidden_states: torch.Tensor,
89
+ target_hidden: torch.Tensor,
90
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
91
+ attention_mask: Optional[torch.Tensor],
92
+ past_key_values: Optional[Cache] = None,
93
+ cache_position: Optional[torch.LongTensor] = None,
94
+ **kwargs: Unpack[FlashAttentionKwargs],
95
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
96
+ bsz, q_len = hidden_states.shape[:-1]
97
+ ctx_len = target_hidden.shape[1]
98
+ q = self.q_proj(hidden_states)
99
+ q = q.view(bsz, q_len, -1, self.head_dim)
100
+ q = self.q_norm(q).transpose(1, 2)
101
+ k_ctx = self.k_proj(target_hidden)
102
+ k_noise = self.k_proj(hidden_states)
103
+ v_ctx = self.v_proj(target_hidden)
104
+ v_noise = self.v_proj(hidden_states)
105
+ k = torch.cat([k_ctx, k_noise], dim=1).view(
106
+ bsz, ctx_len + q_len, -1, self.head_dim
107
+ )
108
+ v = torch.cat([v_ctx, v_noise], dim=1).view(
109
+ bsz, ctx_len + q_len, -1, self.head_dim
110
+ )
111
+ k = self.k_norm(k).transpose(1, 2)
112
+ v = v.transpose(1, 2)
113
+ cos, sin = position_embeddings
114
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
115
+ if past_key_values is not None:
116
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
117
+ k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs)
118
+ attn_fn: Callable = eager_attention_forward
119
+ if self.config._attn_implementation != "eager":
120
+ attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
121
+ attn_output, attn_weights = attn_fn(
122
+ self,
123
+ q,
124
+ k,
125
+ v,
126
+ attention_mask,
127
+ dropout=0.0 if not self.training else self.attention_dropout,
128
+ scaling=self.scaling,
129
+ sliding_window=self.sliding_window,
130
+ **kwargs,
131
+ )
132
+ attn_output = attn_output.reshape(bsz, q_len, -1)
133
+ attn_output = self.o_proj(attn_output)
134
+ return attn_output, attn_weights
135
+
136
+
137
+ class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer):
138
+ def __init__(self, config: Qwen3Config, layer_idx: int):
139
+ super().__init__()
140
+ self.hidden_size = config.hidden_size
141
+ self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx)
142
+ self.mlp = Qwen3MLP(config)
143
+ self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
144
+ self.post_attention_layernorm = Qwen3RMSNorm(
145
+ config.hidden_size, eps=config.rms_norm_eps
146
+ )
147
+
148
+ def forward(
149
+ self,
150
+ target_hidden: Optional[torch.Tensor] = None,
151
+ hidden_states: Optional[torch.Tensor] = None,
152
+ attention_mask: Optional[torch.Tensor] = None,
153
+ position_ids: Optional[torch.LongTensor] = None,
154
+ past_key_value: Optional[Cache] = None,
155
+ output_attentions: Optional[bool] = False,
156
+ use_cache: Optional[bool] = False,
157
+ cache_position: Optional[torch.LongTensor] = None,
158
+ position_embeddings: Optional[
159
+ Tuple[torch.Tensor, torch.Tensor]
160
+ ] = None, # necessary, but kept here for BC
161
+ **kwargs: Unpack[FlashAttentionKwargs],
162
+ ) -> Tuple[
163
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
164
+ ]:
165
+ residual = hidden_states
166
+ hidden_states = self.input_layernorm(hidden_states)
167
+ hidden_states = self.self_attn(
168
+ hidden_states=hidden_states,
169
+ target_hidden=target_hidden,
170
+ attention_mask=attention_mask,
171
+ position_ids=position_ids,
172
+ past_key_values=past_key_value,
173
+ output_attentions=output_attentions,
174
+ use_cache=use_cache,
175
+ cache_position=cache_position,
176
+ position_embeddings=position_embeddings,
177
+ **kwargs,
178
+ )[0]
179
+ hidden_states = residual + hidden_states
180
+ residual = hidden_states
181
+ hidden_states = self.post_attention_layernorm(hidden_states)
182
+ hidden_states = self.mlp(hidden_states)
183
+ hidden_states = residual + hidden_states
184
+ return hidden_states
185
+
186
+
187
+ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int):
188
+ if num_draft_layers == 1:
189
+ return [(num_target_layers // 2)]
190
+ start = 1
191
+ end = num_target_layers - 3
192
+ span = end - start
193
+ target_layer_ids = [
194
+ int(round(start + (i * span) / (num_draft_layers - 1)))
195
+ for i in range(num_draft_layers)
196
+ ]
197
+ return target_layer_ids
198
+
199
+
200
+ def extract_context_feature(
201
+ hidden_states: list[torch.Tensor],
202
+ layer_ids: Optional[list[int]],
203
+ ) -> torch.Tensor:
204
+ offset = 1
205
+ selected_states = []
206
+ for layer_id in layer_ids:
207
+ selected_states.append(hidden_states[layer_id + offset])
208
+ target_hidden = torch.cat(selected_states, dim=-1)
209
+ return target_hidden
210
+
211
+
212
+ class DFlashDraftModel(Qwen3PreTrainedModel):
213
+ config_class = Qwen3Config
214
+ _no_split_modules = ["Qwen3DFlashDecoderLayer"]
215
+
216
+ def __init__(self, config) -> None:
217
+ super().__init__(config)
218
+ self.config = config
219
+ self.layers = nn.ModuleList(
220
+ [
221
+ Qwen3DFlashDecoderLayer(config, layer_idx)
222
+ for layer_idx in range(config.num_hidden_layers)
223
+ ]
224
+ )
225
+ dflash_config = getattr(config, "dflash_config", {}) or {}
226
+ self.target_layer_ids = dflash_config.get(
227
+ "target_layer_ids",
228
+ build_target_layer_ids(config.num_target_layers, config.num_hidden_layers),
229
+ )
230
+ self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
231
+ self.rotary_emb = Qwen3RotaryEmbedding(config)
232
+ self.fc = nn.Linear(
233
+ len(self.target_layer_ids) * config.hidden_size,
234
+ config.hidden_size,
235
+ bias=False,
236
+ )
237
+ self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
238
+ self.block_size = config.block_size
239
+ self.mask_token_id = dflash_config.get("mask_token_id", None)
240
+ self.post_init()
241
+
242
+ def forward(
243
+ self,
244
+ position_ids: torch.LongTensor,
245
+ attention_mask: Optional[torch.Tensor] = None,
246
+ noise_embedding: Optional[torch.Tensor] = None,
247
+ target_hidden: Optional[torch.Tensor] = None,
248
+ past_key_values: Optional[Cache] = None,
249
+ use_cache: bool = False,
250
+ **kwargs,
251
+ ) -> CausalLMOutputWithPast:
252
+ hidden_states = noise_embedding
253
+ target_hidden = self.hidden_norm(self.fc(target_hidden))
254
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
255
+ for layer in self.layers:
256
+ hidden_states = layer(
257
+ hidden_states=hidden_states,
258
+ target_hidden=target_hidden,
259
+ attention_mask=attention_mask,
260
+ position_ids=position_ids,
261
+ past_key_value=past_key_values,
262
+ use_cache=use_cache,
263
+ position_embeddings=position_embeddings,
264
+ **kwargs,
265
+ )
266
+ return self.norm(hidden_states)
267
+
268
+ @torch.inference_mode()
269
+ def spec_generate(
270
+ self,
271
+ target: nn.Module,
272
+ input_ids: torch.LongTensor,
273
+ max_new_tokens: int,
274
+ stop_token_ids: list[int],
275
+ temperature: float,
276
+ ):
277
+ self.eval()
278
+ num_input_tokens = input_ids.shape[1]
279
+ max_length = num_input_tokens + max_new_tokens
280
+
281
+ block_size = self.block_size
282
+ output_ids = torch.full(
283
+ (1, max_length + block_size),
284
+ self.mask_token_id,
285
+ dtype=torch.long,
286
+ device=target.device,
287
+ )
288
+ position_ids = torch.arange(
289
+ output_ids.shape[1], device=target.device
290
+ ).unsqueeze(0)
291
+
292
+ past_key_values_target = DynamicCache()
293
+ past_key_values_draft = DynamicCache()
294
+
295
+ # Prefill stage
296
+ output = target(
297
+ input_ids,
298
+ position_ids=position_ids[:, :num_input_tokens],
299
+ past_key_values=past_key_values_target,
300
+ use_cache=True,
301
+ logits_to_keep=1,
302
+ output_hidden_states=True,
303
+ )
304
+
305
+ output_ids[:, :num_input_tokens] = input_ids
306
+ output_ids[:, num_input_tokens : num_input_tokens + 1] = sample(
307
+ output.logits, temperature
308
+ )
309
+ target_hidden = extract_context_feature(
310
+ output.hidden_states, self.target_layer_ids
311
+ )
312
+
313
+ # Decode stage
314
+ acceptance_lengths = []
315
+ start = input_ids.shape[1]
316
+ while start < max_length:
317
+ block_output_ids = output_ids[:, start : start + block_size].clone()
318
+ block_position_ids = position_ids[:, start : start + block_size]
319
+ noise_embedding = target.model.embed_tokens(block_output_ids)
320
+ draft_logits = target.lm_head(
321
+ self(
322
+ target_hidden=target_hidden,
323
+ noise_embedding=noise_embedding,
324
+ position_ids=position_ids[
325
+ :, past_key_values_draft.get_seq_length() : start + block_size
326
+ ],
327
+ past_key_values=past_key_values_draft,
328
+ use_cache=True,
329
+ is_causal=False,
330
+ )[:, -block_size + 1 :, :]
331
+ )
332
+ past_key_values_draft.crop(start)
333
+ block_output_ids[:, 1:] = sample(draft_logits)
334
+
335
+ output = target(
336
+ block_output_ids,
337
+ position_ids=block_position_ids,
338
+ past_key_values=past_key_values_target,
339
+ use_cache=True,
340
+ output_hidden_states=True,
341
+ )
342
+
343
+ posterior = sample(output.logits, temperature)
344
+ acceptance_length = (
345
+ (block_output_ids[:, 1:] == posterior[:, :-1])
346
+ .cumprod(dim=1)
347
+ .sum(dim=1)[0]
348
+ .item()
349
+ )
350
+ output_ids[:, start : start + acceptance_length + 1] = block_output_ids[
351
+ :, : acceptance_length + 1
352
+ ]
353
+ output_ids[:, start + acceptance_length + 1] = posterior[
354
+ :, acceptance_length
355
+ ]
356
+ start += acceptance_length + 1
357
+ past_key_values_target.crop(start)
358
+ target_hidden = extract_context_feature(
359
+ output.hidden_states, self.target_layer_ids
360
+ )[:, : acceptance_length + 1, :]
361
+ acceptance_lengths.append(acceptance_length + 1)
362
+ if stop_token_ids is not None and any(
363
+ stop_token_id in output_ids[:, num_input_tokens:]
364
+ for stop_token_id in stop_token_ids
365
+ ):
366
+ break
367
+ output_ids = output_ids[:, :max_length]
368
+ output_ids = output_ids[:, output_ids[0] != self.mask_token_id]
369
+ if stop_token_ids is not None:
370
+ stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
371
+ stop_token_indices = torch.isin(
372
+ output_ids[0][num_input_tokens:], stop_token_ids
373
+ ).nonzero(as_tuple=True)[0]
374
+ if stop_token_indices.numel() > 0:
375
+ output_ids = output_ids[
376
+ :, : num_input_tokens + stop_token_indices[0] + 1
377
+ ]
378
+
379
+ return output_ids
dflash-hidden5-target5-block32/epoch_6_step_53334/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62024dbf48ba504902dbf52ffe45894188c200438911ff77db62289784029e5d
3
+ size 2097259104
dflash-hidden5-target5-block32/epoch_6_step_53334/training_state.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff586bce67a8b831804023d763f867e37c9df7e60395e98eb948fbee9d78faa6
3
+ size 2293305969
edit-dflash-hidden5-target5-block16-edit-hidden5/epoch_1_step_55000/config.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DFlashDraftModel"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "block_size": 16,
8
+ "bos_token_id": 151643,
9
+ "dflash_config": {
10
+ "mask_token_id": 151669,
11
+ "target_layer_ids": [
12
+ 1,
13
+ 9,
14
+ 17,
15
+ 25,
16
+ 33
17
+ ]
18
+ },
19
+ "dtype": "bfloat16",
20
+ "eos_token_id": 151645,
21
+ "head_dim": 128,
22
+ "hidden_act": "silu",
23
+ "hidden_size": 4096,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 12288,
26
+ "layer_types": [
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention",
53
+ "full_attention",
54
+ "full_attention",
55
+ "full_attention",
56
+ "full_attention",
57
+ "full_attention",
58
+ "full_attention",
59
+ "full_attention",
60
+ "full_attention",
61
+ "full_attention",
62
+ "full_attention"
63
+ ],
64
+ "max_position_embeddings": 40960,
65
+ "max_window_layers": 36,
66
+ "model_type": "qwen3",
67
+ "num_attention_heads": 32,
68
+ "num_hidden_layers": 5,
69
+ "num_key_value_heads": 8,
70
+ "num_target_capture_layers": 5,
71
+ "num_target_layers": 36,
72
+ "rms_norm_eps": 1e-06,
73
+ "rope_scaling": null,
74
+ "rope_theta": 1000000,
75
+ "sliding_window": null,
76
+ "target_layer_ids": [
77
+ 1,
78
+ 9,
79
+ 17,
80
+ 25,
81
+ 33
82
+ ],
83
+ "tie_word_embeddings": false,
84
+ "transformers_version": "4.57.1",
85
+ "use_cache": true,
86
+ "use_sliding_window": false,
87
+ "vocab_size": 151936
88
+ }
edit-dflash-hidden5-target5-block16-edit-hidden5/epoch_1_step_55000/dflash.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers import DynamicCache
6
+ from transformers.cache_utils import Cache
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from transformers.models.qwen3.modeling_qwen3 import (
9
+ ALL_ATTENTION_FUNCTIONS,
10
+ FlashAttentionKwargs,
11
+ GradientCheckpointingLayer,
12
+ Qwen3Config,
13
+ Qwen3MLP,
14
+ Qwen3PreTrainedModel,
15
+ Qwen3RMSNorm,
16
+ Qwen3RotaryEmbedding,
17
+ eager_attention_forward,
18
+ rotate_half,
19
+ )
20
+ from typing_extensions import Tuple, Unpack
21
+
22
+
23
+ def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
24
+ if temperature < 1e-5:
25
+ return torch.argmax(logits, dim=-1)
26
+ bsz, seq_len, vocab_size = logits.shape
27
+ logits = logits.view(-1, vocab_size)
28
+ logits = logits / temperature
29
+ probs = torch.softmax(logits, dim=-1)
30
+ return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)
31
+
32
+
33
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
34
+ cos = cos.unsqueeze(unsqueeze_dim)
35
+ sin = sin.unsqueeze(unsqueeze_dim)
36
+ q_len = q.size(-2)
37
+ q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :])
38
+ k_embed = (k * cos) + (rotate_half(k) * sin)
39
+ return q_embed, k_embed
40
+
41
+
42
+ class Qwen3DFlashAttention(nn.Module):
43
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
44
+
45
+ def __init__(self, config: Qwen3Config, layer_idx: int):
46
+ super().__init__()
47
+ self.config = config
48
+ self.layer_idx = layer_idx
49
+ self.head_dim = getattr(
50
+ config, "head_dim", config.hidden_size // config.num_attention_heads
51
+ )
52
+ self.num_key_value_groups = (
53
+ config.num_attention_heads // config.num_key_value_heads
54
+ )
55
+ self.scaling = self.head_dim**-0.5
56
+ self.attention_dropout = config.attention_dropout
57
+ self.is_causal = False
58
+ self.q_proj = nn.Linear(
59
+ config.hidden_size,
60
+ config.num_attention_heads * self.head_dim,
61
+ bias=config.attention_bias,
62
+ )
63
+ self.k_proj = nn.Linear(
64
+ config.hidden_size,
65
+ config.num_key_value_heads * self.head_dim,
66
+ bias=config.attention_bias,
67
+ )
68
+ self.v_proj = nn.Linear(
69
+ config.hidden_size,
70
+ config.num_key_value_heads * self.head_dim,
71
+ bias=config.attention_bias,
72
+ )
73
+ self.o_proj = nn.Linear(
74
+ config.num_attention_heads * self.head_dim,
75
+ config.hidden_size,
76
+ bias=config.attention_bias,
77
+ )
78
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
79
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
80
+ self.sliding_window = (
81
+ config.sliding_window
82
+ if config.layer_types[layer_idx] == "sliding_attention"
83
+ else None
84
+ )
85
+
86
+ def forward(
87
+ self,
88
+ hidden_states: torch.Tensor,
89
+ target_hidden: torch.Tensor,
90
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
91
+ attention_mask: Optional[torch.Tensor],
92
+ past_key_values: Optional[Cache] = None,
93
+ cache_position: Optional[torch.LongTensor] = None,
94
+ **kwargs: Unpack[FlashAttentionKwargs],
95
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
96
+ bsz, q_len = hidden_states.shape[:-1]
97
+ ctx_len = target_hidden.shape[1]
98
+ q = self.q_proj(hidden_states)
99
+ q = q.view(bsz, q_len, -1, self.head_dim)
100
+ q = self.q_norm(q).transpose(1, 2)
101
+ k_ctx = self.k_proj(target_hidden)
102
+ k_noise = self.k_proj(hidden_states)
103
+ v_ctx = self.v_proj(target_hidden)
104
+ v_noise = self.v_proj(hidden_states)
105
+ k = torch.cat([k_ctx, k_noise], dim=1).view(
106
+ bsz, ctx_len + q_len, -1, self.head_dim
107
+ )
108
+ v = torch.cat([v_ctx, v_noise], dim=1).view(
109
+ bsz, ctx_len + q_len, -1, self.head_dim
110
+ )
111
+ k = self.k_norm(k).transpose(1, 2)
112
+ v = v.transpose(1, 2)
113
+ cos, sin = position_embeddings
114
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
115
+ if past_key_values is not None:
116
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
117
+ k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs)
118
+ attn_fn: Callable = eager_attention_forward
119
+ if self.config._attn_implementation != "eager":
120
+ attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
121
+ attn_output, attn_weights = attn_fn(
122
+ self,
123
+ q,
124
+ k,
125
+ v,
126
+ attention_mask,
127
+ dropout=0.0 if not self.training else self.attention_dropout,
128
+ scaling=self.scaling,
129
+ sliding_window=self.sliding_window,
130
+ **kwargs,
131
+ )
132
+ attn_output = attn_output.reshape(bsz, q_len, -1)
133
+ attn_output = self.o_proj(attn_output)
134
+ return attn_output, attn_weights
135
+
136
+
137
+ class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer):
138
+ def __init__(self, config: Qwen3Config, layer_idx: int):
139
+ super().__init__()
140
+ self.hidden_size = config.hidden_size
141
+ self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx)
142
+ self.mlp = Qwen3MLP(config)
143
+ self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
144
+ self.post_attention_layernorm = Qwen3RMSNorm(
145
+ config.hidden_size, eps=config.rms_norm_eps
146
+ )
147
+
148
+ def forward(
149
+ self,
150
+ target_hidden: Optional[torch.Tensor] = None,
151
+ hidden_states: Optional[torch.Tensor] = None,
152
+ attention_mask: Optional[torch.Tensor] = None,
153
+ position_ids: Optional[torch.LongTensor] = None,
154
+ past_key_value: Optional[Cache] = None,
155
+ output_attentions: Optional[bool] = False,
156
+ use_cache: Optional[bool] = False,
157
+ cache_position: Optional[torch.LongTensor] = None,
158
+ position_embeddings: Optional[
159
+ Tuple[torch.Tensor, torch.Tensor]
160
+ ] = None, # necessary, but kept here for BC
161
+ **kwargs: Unpack[FlashAttentionKwargs],
162
+ ) -> Tuple[
163
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
164
+ ]:
165
+ residual = hidden_states
166
+ hidden_states = self.input_layernorm(hidden_states)
167
+ hidden_states = self.self_attn(
168
+ hidden_states=hidden_states,
169
+ target_hidden=target_hidden,
170
+ attention_mask=attention_mask,
171
+ position_ids=position_ids,
172
+ past_key_values=past_key_value,
173
+ output_attentions=output_attentions,
174
+ use_cache=use_cache,
175
+ cache_position=cache_position,
176
+ position_embeddings=position_embeddings,
177
+ **kwargs,
178
+ )[0]
179
+ hidden_states = residual + hidden_states
180
+ residual = hidden_states
181
+ hidden_states = self.post_attention_layernorm(hidden_states)
182
+ hidden_states = self.mlp(hidden_states)
183
+ hidden_states = residual + hidden_states
184
+ return hidden_states
185
+
186
+
187
+ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int):
188
+ if num_draft_layers == 1:
189
+ return [(num_target_layers // 2)]
190
+ start = 1
191
+ end = num_target_layers - 3
192
+ span = end - start
193
+ target_layer_ids = [
194
+ int(round(start + (i * span) / (num_draft_layers - 1)))
195
+ for i in range(num_draft_layers)
196
+ ]
197
+ return target_layer_ids
198
+
199
+
200
+ def extract_context_feature(
201
+ hidden_states: list[torch.Tensor],
202
+ layer_ids: Optional[list[int]],
203
+ ) -> torch.Tensor:
204
+ offset = 1
205
+ selected_states = []
206
+ for layer_id in layer_ids:
207
+ selected_states.append(hidden_states[layer_id + offset])
208
+ target_hidden = torch.cat(selected_states, dim=-1)
209
+ return target_hidden
210
+
211
+
212
+ class DFlashDraftModel(Qwen3PreTrainedModel):
213
+ config_class = Qwen3Config
214
+ _no_split_modules = ["Qwen3DFlashDecoderLayer"]
215
+
216
+ def __init__(self, config) -> None:
217
+ super().__init__(config)
218
+ self.config = config
219
+ self.layers = nn.ModuleList(
220
+ [
221
+ Qwen3DFlashDecoderLayer(config, layer_idx)
222
+ for layer_idx in range(config.num_hidden_layers)
223
+ ]
224
+ )
225
+ dflash_config = getattr(config, "dflash_config", {}) or {}
226
+ self.target_layer_ids = dflash_config.get(
227
+ "target_layer_ids",
228
+ build_target_layer_ids(config.num_target_layers, config.num_hidden_layers),
229
+ )
230
+ self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
231
+ self.rotary_emb = Qwen3RotaryEmbedding(config)
232
+ self.fc = nn.Linear(
233
+ len(self.target_layer_ids) * config.hidden_size,
234
+ config.hidden_size,
235
+ bias=False,
236
+ )
237
+ self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
238
+ self.block_size = config.block_size
239
+ self.mask_token_id = dflash_config.get("mask_token_id", None)
240
+ self.post_init()
241
+
242
+ def forward(
243
+ self,
244
+ position_ids: torch.LongTensor,
245
+ attention_mask: Optional[torch.Tensor] = None,
246
+ noise_embedding: Optional[torch.Tensor] = None,
247
+ target_hidden: Optional[torch.Tensor] = None,
248
+ past_key_values: Optional[Cache] = None,
249
+ use_cache: bool = False,
250
+ **kwargs,
251
+ ) -> CausalLMOutputWithPast:
252
+ hidden_states = noise_embedding
253
+ target_hidden = self.hidden_norm(self.fc(target_hidden))
254
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
255
+ for layer in self.layers:
256
+ hidden_states = layer(
257
+ hidden_states=hidden_states,
258
+ target_hidden=target_hidden,
259
+ attention_mask=attention_mask,
260
+ position_ids=position_ids,
261
+ past_key_value=past_key_values,
262
+ use_cache=use_cache,
263
+ position_embeddings=position_embeddings,
264
+ **kwargs,
265
+ )
266
+ return self.norm(hidden_states)
267
+
268
+ @torch.inference_mode()
269
+ def spec_generate(
270
+ self,
271
+ target: nn.Module,
272
+ input_ids: torch.LongTensor,
273
+ max_new_tokens: int,
274
+ stop_token_ids: list[int],
275
+ temperature: float,
276
+ ):
277
+ self.eval()
278
+ num_input_tokens = input_ids.shape[1]
279
+ max_length = num_input_tokens + max_new_tokens
280
+
281
+ block_size = self.block_size
282
+ output_ids = torch.full(
283
+ (1, max_length + block_size),
284
+ self.mask_token_id,
285
+ dtype=torch.long,
286
+ device=target.device,
287
+ )
288
+ position_ids = torch.arange(
289
+ output_ids.shape[1], device=target.device
290
+ ).unsqueeze(0)
291
+
292
+ past_key_values_target = DynamicCache()
293
+ past_key_values_draft = DynamicCache()
294
+
295
+ # Prefill stage
296
+ output = target(
297
+ input_ids,
298
+ position_ids=position_ids[:, :num_input_tokens],
299
+ past_key_values=past_key_values_target,
300
+ use_cache=True,
301
+ logits_to_keep=1,
302
+ output_hidden_states=True,
303
+ )
304
+
305
+ output_ids[:, :num_input_tokens] = input_ids
306
+ output_ids[:, num_input_tokens : num_input_tokens + 1] = sample(
307
+ output.logits, temperature
308
+ )
309
+ target_hidden = extract_context_feature(
310
+ output.hidden_states, self.target_layer_ids
311
+ )
312
+
313
+ # Decode stage
314
+ acceptance_lengths = []
315
+ start = input_ids.shape[1]
316
+ while start < max_length:
317
+ block_output_ids = output_ids[:, start : start + block_size].clone()
318
+ block_position_ids = position_ids[:, start : start + block_size]
319
+ noise_embedding = target.model.embed_tokens(block_output_ids)
320
+ draft_logits = target.lm_head(
321
+ self(
322
+ target_hidden=target_hidden,
323
+ noise_embedding=noise_embedding,
324
+ position_ids=position_ids[
325
+ :, past_key_values_draft.get_seq_length() : start + block_size
326
+ ],
327
+ past_key_values=past_key_values_draft,
328
+ use_cache=True,
329
+ is_causal=False,
330
+ )[:, -block_size + 1 :, :]
331
+ )
332
+ past_key_values_draft.crop(start)
333
+ block_output_ids[:, 1:] = sample(draft_logits)
334
+
335
+ output = target(
336
+ block_output_ids,
337
+ position_ids=block_position_ids,
338
+ past_key_values=past_key_values_target,
339
+ use_cache=True,
340
+ output_hidden_states=True,
341
+ )
342
+
343
+ posterior = sample(output.logits, temperature)
344
+ acceptance_length = (
345
+ (block_output_ids[:, 1:] == posterior[:, :-1])
346
+ .cumprod(dim=1)
347
+ .sum(dim=1)[0]
348
+ .item()
349
+ )
350
+ output_ids[:, start : start + acceptance_length + 1] = block_output_ids[
351
+ :, : acceptance_length + 1
352
+ ]
353
+ output_ids[:, start + acceptance_length + 1] = posterior[
354
+ :, acceptance_length
355
+ ]
356
+ start += acceptance_length + 1
357
+ past_key_values_target.crop(start)
358
+ target_hidden = extract_context_feature(
359
+ output.hidden_states, self.target_layer_ids
360
+ )[:, : acceptance_length + 1, :]
361
+ acceptance_lengths.append(acceptance_length + 1)
362
+ if stop_token_ids is not None and any(
363
+ stop_token_id in output_ids[:, num_input_tokens:]
364
+ for stop_token_id in stop_token_ids
365
+ ):
366
+ break
367
+ output_ids = output_ids[:, :max_length]
368
+ output_ids = output_ids[:, output_ids[0] != self.mask_token_id]
369
+ if stop_token_ids is not None:
370
+ stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
371
+ stop_token_indices = torch.isin(
372
+ output_ids[0][num_input_tokens:], stop_token_ids
373
+ ).nonzero(as_tuple=True)[0]
374
+ if stop_token_indices.numel() > 0:
375
+ output_ids = output_ids[
376
+ :, : num_input_tokens + stop_token_indices[0] + 1
377
+ ]
378
+
379
+ return output_ids
edit-dflash-hidden5-target5-block16-edit-hidden5/epoch_1_step_55000/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e94e6feca29409aa45d4ae851ea70c902a068ea1f477d1eb19e75f264b9209a
3
+ size 2097259104
edit-dflash-hidden5-target5-block16-edit-hidden5/epoch_1_step_55000/training_state.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90b265cf6d589dc09afc5ffa2cd8ef44365f58e64a826f39efaeba93b4f0f35b
3
+ size 2293306161