shanjiaz commited on
Commit
cc17261
·
verified ·
1 Parent(s): c1a25a5

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. config.json +58 -0
  2. eagle3.py +545 -0
  3. generation_config.json +4 -0
  4. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Eagle3Speculator"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "eagle3.Eagle3SpeculatorConfig"
7
+ },
8
+ "draft_vocab_size": 32000,
9
+ "eagle_aux_hidden_state_layer_ids": null,
10
+ "has_no_defaults_at_init": false,
11
+ "inference_type": null,
12
+ "norm_before_residual": false,
13
+ "speculators_config": {
14
+ "algorithm": "eagle3",
15
+ "default_proposal_method": "greedy",
16
+ "proposal_methods": [
17
+ {
18
+ "accept_tolerance": 0.0,
19
+ "proposal_type": "greedy",
20
+ "speculative_tokens": 5,
21
+ "verifier_accept_k": 1
22
+ }
23
+ ],
24
+ "verifier": {
25
+ "architectures": [
26
+ "LlamaForCausalLM"
27
+ ],
28
+ "name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct"
29
+ }
30
+ },
31
+ "speculators_model_type": "eagle3",
32
+ "speculators_version": "0.2.0.dev16",
33
+ "target_hidden_size": null,
34
+ "torch_dtype": "float16",
35
+ "transformer_layer_config": {
36
+ "attention_bias": false,
37
+ "attention_dropout": 0.0,
38
+ "head_dim": 128,
39
+ "hidden_act": "silu",
40
+ "hidden_size": 4096,
41
+ "initializer_range": 0.02,
42
+ "intermediate_size": 14336,
43
+ "max_position_embeddings": 131072,
44
+ "mlp_bias": false,
45
+ "model_type": "llama",
46
+ "num_attention_heads": 32,
47
+ "num_hidden_layers": 1,
48
+ "num_key_value_heads": 8,
49
+ "pretraining_tp": 1,
50
+ "rms_norm_eps": 1e-05,
51
+ "rope_scaling": null,
52
+ "rope_theta": 10000.0,
53
+ "torch_dtype": "float16",
54
+ "use_cache": true,
55
+ "vocab_size": 128256
56
+ },
57
+ "transformers_version": "4.53.2"
58
+ }
eagle3.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speculators implementation of EAGLE-3:
3
+ - https://arxiv.org/abs/2503.01840
4
+
5
+ Classes:
6
+ Eagle3SpeculatorConfig: Configuration class for EAGLE-3 speculator model
7
+ Eagle3Speculator: Main model implementation for EAGLE-3 speculators
8
+ Eagle3Attention: Custom attention layer for EAGLE-3, processes
9
+ concatenated embeddings and hidden states
10
+ Eagle3DecoderLayer: Custom decoder layer for EAGLE-3, processes
11
+ concatenated embeddings and hidden states with Eagle3Attention
12
+ and support for moving hidden layernorm before residual
13
+ """
14
+
15
+ import os
16
+ from typing import Any, ClassVar, Literal, Optional, Union
17
+
18
+ import torch
19
+ from pydantic import Field, field_serializer, field_validator
20
+ from torch import nn
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+ from transformers.models.llama.configuration_llama import LlamaConfig
25
+ from transformers.models.llama.modeling_llama import (
26
+ LlamaMLP,
27
+ LlamaRMSNorm,
28
+ apply_rotary_pos_emb,
29
+ repeat_kv,
30
+ )
31
+
32
+ from speculators import SpeculatorModel, SpeculatorModelConfig
33
+
34
+ __all__ = [
35
+ "Eagle3Attention",
36
+ "Eagle3DecoderLayer",
37
+ "Eagle3Speculator",
38
+ "Eagle3SpeculatorConfig",
39
+ ]
40
+
41
+
42
@SpeculatorModelConfig.register("eagle3")
class Eagle3SpeculatorConfig(SpeculatorModelConfig):
    """
    Configuration for EAGLE-3 speculator with vocabulary mapping.

    EAGLE-3 features vocabulary mapping between draft (32K) and target (128K)
    vocabularies, enabling cross-tokenizer speculation.

    :param transformer_layer_config: Configuration for the transformer decoder layer
    :param draft_vocab_size: Size of draft model vocabulary for speculation
    :param norm_before_residual: Apply hidden_norm before storing residual
    """

    # Discriminator consumed by the SpeculatorModelConfig registry (see the
    # @register decorator above); fixed to "eagle3" for this config class.
    speculators_model_type: Literal["eagle3"] = "eagle3"
    architectures: list[str] = Field(
        default_factory=lambda: ["Eagle3Speculator"],
        description="Model architectures that can load these weights",
    )

    # Full HF config for the single decoder layer; defaults to a stock
    # LlamaConfig when none is supplied.
    transformer_layer_config: PretrainedConfig = Field(
        default_factory=LlamaConfig,
        description="Configuration for the transformer decoder layer",
    )

    draft_vocab_size: int = Field(
        default=32000,
        description="Size of draft model vocabulary for speculation",
    )

    norm_before_residual: bool = Field(
        default=False,
        description="Apply hidden_norm before storing residual",
    )

    target_hidden_size: Optional[int] = Field(
        default=None,
        description="Hidden size of the target model (if different from draft model)",
    )

    eagle_aux_hidden_state_layer_ids: Optional[list[int]] = Field(
        default=None,
        description="Layer IDs of the Eagle auxiliary hidden state layers",
    )

    inference_type: Optional[str] = Field(
        default="text",
        description="Inference type of the speculator",
    )

    @property
    def target_vocab_size(self) -> int:
        """Get target vocabulary size from the transformer layer config."""
        return self.transformer_layer_config.vocab_size

    @field_serializer("transformer_layer_config")
    def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
        """Serialize transformer config to dict.

        Uses to_diff_dict() so only values differing from the config class
        defaults are written out.
        """
        return value.to_diff_dict()

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
        """Validate and convert transformer config.

        Accepts either an already-built PretrainedConfig (returned as-is) or a
        plain dict. For dicts, the concrete config class is resolved from
        "model_type" via AutoConfig when present, else LlamaConfig is assumed.
        """
        if isinstance(value, dict):
            config_class: type[PretrainedConfig] = LlamaConfig
            if "model_type" in value:
                # Deferred import; resolves the registered config class for
                # the declared model_type by instantiating a default config.
                from transformers import AutoConfig

                config_class = AutoConfig.for_model(
                    model_type=value["model_type"]
                ).__class__
            return config_class(**value)
        return value
115
+
116
+
117
class Eagle3Attention(nn.Module):
    """
    Eagle-3 attention module that processes concatenated embeddings and hidden states.

    Modified from standard Llama attention to accept 2x hidden_size input
    for Q/K/V projections while maintaining the standard output size.
    """

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        # Some configs (e.g. Llama 3.x) declare head_dim explicitly; otherwise
        # derive it the standard way.
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Input is the concatenation [embeddings; fused hidden states],
        # hence 2x hidden_size on the input side only.
        input_size = 2 * self.hidden_size
        self.q_proj = nn.Linear(
            input_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            input_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            input_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # noqa: ARG002
    ) -> tuple:
        """
        Forward pass for Eagle-3 attention.
        Taken from Llama Attention but modified to accept 2x hidden_size input.

        :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
        :param attention_mask: Optional attention mask
        :param position_ids: Optional position IDs for rotary embeddings
        :param past_key_value: Optional cached key-value pairs
        :param output_attentions: Whether to return attention weights
        :param use_cache: Whether to cache key-value pairs
        :param position_embeddings: Optional precomputed rotary embeddings
        :return: Tuple of (hidden_states, [attention_weights], [past_key_value])
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids
            )

        past_key_value_out = None
        if past_key_value is not None:
            # Legacy tuple-style cache: append new K/V along the sequence axis.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            past_key_value_out = (key_states, value_states)

        # Expand grouped KV heads to match the number of query heads (GQA).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / (
            self.head_dim**0.5
        )

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        # Fix: flatten heads to num_heads * head_dim — the input size of o_proj.
        # Using hidden_size here would raise whenever config.head_dim is set
        # such that num_heads * head_dim != hidden_size.
        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value_out
234
+
235
+
236
class Eagle3DecoderLayer(nn.Module):
    """
    Decoder layer for EAGLE-3 that consumes concatenated embeddings and hidden
    states.

    The incoming tensor is [batch, seq_len, 2*hidden_size]: the first half
    holds token embeddings, the second half the fused verifier hidden states.
    Each half is RMS-normalized independently before Eagle3Attention, followed
    by the usual post-attention norm + MLP residual branch.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        norm_before_residual: bool = False,
    ):
        super().__init__()
        self.norm_before_residual = norm_before_residual
        self.hidden_size = config.hidden_size

        eps = config.rms_norm_eps
        # Separate norms for the embedding half and the hidden-state half.
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=eps)
        self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=eps)

        self.self_attn = Eagle3Attention(config, layer_idx)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,  # noqa: ARG002
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # noqa: ARG002
    ) -> tuple:
        """
        Process concatenated embeddings and hidden states through the modified
        decoder layer.

        :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
        :return: Tuple of layer outputs
        """
        size = self.hidden_size
        embed_half = hidden_states[:, :, :size]
        hidden_half = hidden_states[:, :, size : 2 * size]

        # Normalize the hidden-state half; norm_before_residual decides
        # whether the residual branch keeps the normalized or the raw value.
        normed_hidden = self.hidden_norm(hidden_half)
        residual = normed_hidden if self.norm_before_residual else hidden_half
        hidden_half = normed_hidden

        attn_in = torch.cat(
            [self.input_layernorm(embed_half), hidden_half], dim=-1
        )

        attn_out, attn_weights, new_kv = self.self_attn(
            hidden_states=attn_in,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
        )

        post_attn = residual + attn_out
        out = post_attn + self.mlp(self.post_attention_layernorm(post_attn))

        result: tuple = (out,)
        if output_attentions:
            result += (attn_weights,)
        if use_cache:
            result += (new_kv,)
        return result
323
+
324
+
325
@SpeculatorModel.register("eagle3")
class Eagle3Speculator(SpeculatorModel):
    """
    EAGLE-3 speculator with vocabulary mapping and multi-layer fusion.

    EAGLE-3 processes concatenated hidden states from multiple verifier layers
    through a fusion layer, then combines with embeddings for a custom decoder
    layer that accepts 2x hidden_size input.
    """

    config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig  # type: ignore[misc]
    # Verifier weights are attached at runtime, not stored in this checkpoint,
    # so missing "verifier*" keys at load time are expected.
    _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [  # type: ignore[misc]
        "verifier*",
    ]
    _keys_to_ignore_on_save: ClassVar[list[str]] = []  # type: ignore[misc,assignment]

    def __init__(
        self,
        config: Eagle3SpeculatorConfig,
        verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
        verifier_attachment_mode: Optional[
            Literal["detached", "full", "train_only"]
        ] = None,
    ):
        """
        Initialize Eagle3 speculator.

        :param config: Eagle3SpeculatorConfig instance
        :param verifier: Optional verifier model
        :param verifier_attachment_mode: How to attach the verifier
        """
        if not isinstance(config, Eagle3SpeculatorConfig):
            raise ValueError(
                f"config must be Eagle3SpeculatorConfig, got {type(config)}"
            )

        self.config: Eagle3SpeculatorConfig = config

        # Sizes must be resolved before super().__init__ runs, since the base
        # class may touch the config/model during its own setup.
        self.hidden_size = config.transformer_layer_config.hidden_size
        self.draft_vocab_size = config.draft_vocab_size
        self.target_vocab_size = config.target_vocab_size

        # Use target_hidden_size if specified, otherwise use draft model's hidden_size
        self.target_hidden_size = (
            config.target_hidden_size
            if config.target_hidden_size is not None
            else self.hidden_size
        )

        super().__init__(
            config=config,
            verifier=verifier,
            verifier_attachment_mode=verifier_attachment_mode,
        )

        # Embeds draft input ids over the FULL target vocabulary (see
        # tie_weights below for why this differs from lm_head's size).
        self.embed_tokens = nn.Embedding(
            self.target_vocab_size,
            self.hidden_size,
            padding_idx=config.transformer_layer_config.pad_token_id
            if hasattr(config.transformer_layer_config, "pad_token_id")
            else None,
        )

        # Fuses the concatenated hidden states of 3 verifier layers down to
        # the draft model's hidden size.
        self.fc = nn.Linear(
            3 * self.target_hidden_size,  # Use target model's hidden size
            self.hidden_size,
            bias=False,
        )

        # EAGLE-3 uses a single decoder layer (layer_idx fixed to 0).
        self.layers = nn.ModuleList(
            [
                Eagle3DecoderLayer(
                    config.transformer_layer_config,
                    layer_idx=0,
                    norm_before_residual=config.norm_before_residual,
                )
            ]
        )

        self.norm = LlamaRMSNorm(
            self.hidden_size,
            eps=config.transformer_layer_config.rms_norm_eps,
        )

        # Output head over the reduced DRAFT vocabulary only.
        self.lm_head = nn.Linear(
            self.hidden_size,
            self.draft_vocab_size,
            bias=False,
        )

        self.post_init()  # type: ignore[attr-defined]

    def forward(
        self,
        input_ids: torch.LongTensor,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,  # noqa: ARG002
        return_dict: Optional[bool] = None,
    ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
        """
        Forward pass for EAGLE-3 speculation.

        :param input_ids: Input token IDs from draft vocabulary
        :param hidden_states: Concatenated hidden states from 3 verifier layers
            [B, L, 3*target_H] where target_H is the target model's hidden size
        :param attention_mask: Optional attention mask
        :param position_ids: Optional position IDs
        :param past_key_values: Optional cached key-values
        :param use_cache: Whether to cache key-values
        :param output_attentions: Return attention weights
        :param output_hidden_states: Return hidden states
        :param return_dict: Return dict output
        :return: Model outputs with draft vocabulary logits
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds = self.embed_tokens(input_ids)

        # Fuse [B, L, 3*target_H] -> [B, L, H].
        fused_hidden = self.fc(hidden_states)

        # Decoder layer expects the 2x-wide concat [embeddings; fused hidden].
        layer_input = torch.cat([inputs_embeds, fused_hidden], dim=-1)

        batch_size, seq_length = layer_input.shape[:2]
        # Expand a 2-D padding mask into the 4-D causal mask the attention
        # module adds to its scores.
        if attention_mask is not None and attention_mask.dim() == 2:  # noqa: PLR2004
            past_key_values_length = (
                past_key_values[0][0].shape[2] if past_key_values else 0
            )
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
            )

        if position_ids is None:
            device = hidden_states.device
            position_ids = (
                torch.arange(  # type: ignore[assignment]
                    seq_length, dtype=torch.long, device=device
                )
                .unsqueeze(0)
                .expand(batch_size, -1)
            )

        # NOTE(review): position_ids here do not account for
        # past_key_values_length when a cache is passed — confirm callers
        # supply explicit position_ids for cached decoding.
        layer_outputs = self.layers[0](
            layer_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values[0] if past_key_values else None,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        logits = self.compute_logits(hidden_states, map_to_target_vocab=True)

        if not return_dict:
            return logits

        # NOTE(review): layer_outputs[1] is the KV cache only when
        # output_attentions is falsy; with output_attentions=True and
        # use_cache=True, index 1 holds attention weights instead — verify.
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=[layer_outputs[1]] if use_cache else None,  # type: ignore[arg-type]
            hidden_states=None,
            attentions=None,
        )

    def compute_logits(
        self,
        hidden_states: torch.FloatTensor,
        map_to_target_vocab: bool = True,
    ) -> torch.FloatTensor:
        """
        Compute logits with optional vocabulary mapping.

        :param hidden_states: Hidden states from the model
        :param map_to_target_vocab: Whether to map draft logits to target vocabulary
        :return: Logits tensor
        """
        logits = self.lm_head(hidden_states)

        if not map_to_target_vocab:
            return logits

        batch_size, seq_length, _ = logits.shape

        draft_indices = torch.arange(self.draft_vocab_size, device=logits.device)

        # self.d2t: per-token offset from draft index to target index.
        # Not defined in __init__ — presumably a buffer loaded from the
        # checkpoint weights; TODO confirm.
        target_indices = draft_indices + self.d2t

        # Target-vocab positions not covered by the draft vocab get -inf so
        # they can never be sampled.
        mapped_logits = logits.new_full(
            (batch_size, seq_length, self.target_vocab_size), float("-inf")
        )

        mapped_logits[:, :, target_indices] = logits

        return mapped_logits

    def tie_weights(self):
        """
        Override tie_weights to prevent vocabulary corruption in transformers 4.54.1+

        Eagle3 intentionally uses different vocabulary sizes:
        - Input embeddings (embed_tokens): 128256 (full vocabulary)
        - Output embeddings (lm_head): 32000 (draft vocabulary)

        The default tie_weights() tries to make them identical, breaking Eagle3.
        This override preserves the intentional vocabulary size difference.
        """
        # Don't call super().tie_weights() - this prevents vocabulary corruption
        # that occurs when _tie_or_clone_weights replaces lm_head.weight with
        # embed_tokens.weight
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.53.2"
4
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:976aeac1799e62415d9384ca292913f95f2ae3b6f3c7722f7a99986e02a95f55
3
+ size 1900053960