alexmarques committed on
Commit
cfa3dff
·
verified ·
1 Parent(s): dc05f1c

Upload eagle3.py

Browse files
Files changed (1) hide show
  1. eagle3.py +69 -173
eagle3.py CHANGED
@@ -13,14 +13,12 @@ Classes:
13
  """
14
 
15
  import os
16
- from typing import Any, ClassVar, Literal, Optional, Union
17
 
18
  import torch
19
  from pydantic import Field, field_serializer, field_validator
20
  from torch import nn
21
- from transformers import PretrainedConfig, PreTrainedModel
22
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
23
- from transformers.modeling_outputs import CausalLMOutputWithPast
24
  from transformers.models.llama.configuration_llama import LlamaConfig
25
  from transformers.models.llama.modeling_llama import (
26
  LlamaMLP,
@@ -73,11 +71,16 @@ class Eagle3SpeculatorConfig(SpeculatorModelConfig):
73
  description="Apply hidden_norm before storing residual",
74
  )
75
 
76
- target_hidden_size: Optional[int] = Field(
77
  default=None,
78
  description="Hidden size of the target model (if different from draft model)",
79
  )
80
 
 
 
 
 
 
81
  @property
82
  def target_vocab_size(self) -> int:
83
  """Get target vocabulary size from transformer config."""
@@ -95,8 +98,6 @@ class Eagle3SpeculatorConfig(SpeculatorModelConfig):
95
  if isinstance(value, dict):
96
  config_class: type[PretrainedConfig] = LlamaConfig
97
  if "model_type" in value:
98
- from transformers import AutoConfig
99
-
100
  config_class = AutoConfig.for_model(
101
  model_type=value["model_type"]
102
  ).__class__
@@ -144,12 +145,12 @@ class Eagle3Attention(nn.Module):
144
  def forward(
145
  self,
146
  hidden_states: torch.Tensor,
147
- attention_mask: Optional[torch.Tensor] = None,
148
- position_ids: Optional[torch.LongTensor] = None,
149
- past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
150
  output_attentions: bool = False,
151
  use_cache: bool = False,
152
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
153
  **kwargs, # noqa: ARG002
154
  ) -> tuple:
155
  """
@@ -254,13 +255,13 @@ class Eagle3DecoderLayer(nn.Module):
254
  def forward(
255
  self,
256
  hidden_states: torch.Tensor,
257
- attention_mask: Optional[torch.Tensor] = None,
258
- position_ids: Optional[torch.LongTensor] = None,
259
- past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
260
- output_attentions: Optional[bool] = False,
261
- use_cache: Optional[bool] = False,
262
- cache_position: Optional[torch.LongTensor] = None, # noqa: ARG002
263
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
264
  **kwargs, # noqa: ARG002
265
  ) -> tuple:
266
  """
@@ -331,10 +332,11 @@ class Eagle3Speculator(SpeculatorModel):
331
  def __init__(
332
  self,
333
  config: Eagle3SpeculatorConfig,
334
- verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
335
- verifier_attachment_mode: Optional[
336
- Literal["detached", "full", "train_only"]
337
- ] = None,
 
338
  ):
339
  """
340
  Initialize Eagle3 speculator.
@@ -342,6 +344,8 @@ class Eagle3Speculator(SpeculatorModel):
342
  :param config: Eagle3SpeculatorConfig instance
343
  :param verifier: Optional verifier model
344
  :param verifier_attachment_mode: How to attach the verifier
 
 
345
  """
346
  if not isinstance(config, Eagle3SpeculatorConfig):
347
  raise ValueError(
@@ -367,13 +371,14 @@ class Eagle3Speculator(SpeculatorModel):
367
  verifier_attachment_mode=verifier_attachment_mode,
368
  )
369
 
370
- self.embed_tokens = nn.Embedding(
371
- self.target_vocab_size,
372
- self.hidden_size,
373
- padding_idx=config.transformer_layer_config.pad_token_id
374
- if hasattr(config.transformer_layer_config, "pad_token_id")
375
- else None,
376
- )
 
377
 
378
  self.fc = nn.Linear(
379
  3 * self.target_hidden_size, # Use target model's hidden size
@@ -401,34 +406,48 @@ class Eagle3Speculator(SpeculatorModel):
401
  self.draft_vocab_size,
402
  bias=False,
403
  )
 
 
 
 
 
 
 
 
 
404
 
405
- self.register_buffer( # type: ignore[attr-defined]
406
- "d2t",
407
- torch.zeros(self.draft_vocab_size, dtype=torch.long),
408
- )
409
- self.register_buffer( # type: ignore[attr-defined]
410
- "t2d",
411
- torch.zeros(self.target_vocab_size, dtype=torch.bool),
412
- )
413
 
414
- # Type hints for buffers
415
- self.d2t: torch.Tensor
416
- self.t2d: torch.Tensor
417
 
418
- self.post_init() # type: ignore[attr-defined]
 
 
 
 
 
 
 
 
 
419
 
420
  def forward(
421
  self,
422
  input_ids: torch.LongTensor,
423
  hidden_states: torch.FloatTensor,
424
- attention_mask: Optional[torch.Tensor] = None,
425
- position_ids: Optional[torch.LongTensor] = None,
426
- past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
427
- use_cache: Optional[bool] = None,
428
- output_attentions: Optional[bool] = None,
429
- output_hidden_states: Optional[bool] = None, # noqa: ARG002
430
- return_dict: Optional[bool] = None,
431
- ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
432
  """
433
  Forward pass for EAGLE-3 speculation.
434
 
@@ -444,127 +463,4 @@ class Eagle3Speculator(SpeculatorModel):
444
  :param return_dict: Return dict output
445
  :return: Model outputs with draft vocabulary logits
446
  """
447
- return_dict = (
448
- return_dict if return_dict is not None else self.config.use_return_dict
449
- )
450
-
451
- inputs_embeds = self.embed_tokens(input_ids)
452
-
453
- fused_hidden = self.fc(hidden_states)
454
-
455
- layer_input = torch.cat([inputs_embeds, fused_hidden], dim=-1)
456
-
457
- batch_size, seq_length = layer_input.shape[:2]
458
- if attention_mask is not None and attention_mask.dim() == 2: # noqa: PLR2004
459
- past_key_values_length = (
460
- past_key_values[0][0].shape[2] if past_key_values else 0
461
- )
462
- attention_mask = _prepare_4d_causal_attention_mask(
463
- attention_mask,
464
- (batch_size, seq_length),
465
- hidden_states,
466
- past_key_values_length,
467
- )
468
-
469
- if position_ids is None:
470
- device = hidden_states.device
471
- position_ids = (
472
- torch.arange( # type: ignore[assignment]
473
- seq_length, dtype=torch.long, device=device
474
- )
475
- .unsqueeze(0)
476
- .expand(batch_size, -1)
477
- )
478
-
479
- layer_outputs = self.layers[0](
480
- layer_input,
481
- attention_mask=attention_mask,
482
- position_ids=position_ids,
483
- past_key_value=past_key_values[0] if past_key_values else None,
484
- output_attentions=output_attentions,
485
- use_cache=use_cache,
486
- )
487
-
488
- hidden_states = layer_outputs[0]
489
-
490
- hidden_states = self.norm(hidden_states)
491
-
492
- logits = self.compute_logits(hidden_states, map_to_target_vocab=True)
493
-
494
- if not return_dict:
495
- return logits
496
-
497
- return CausalLMOutputWithPast(
498
- logits=logits,
499
- past_key_values=[layer_outputs[1]] if use_cache else None, # type: ignore[arg-type]
500
- hidden_states=None,
501
- attentions=None,
502
- )
503
-
504
- def compute_logits(
505
- self,
506
- hidden_states: torch.FloatTensor,
507
- map_to_target_vocab: bool = True,
508
- ) -> torch.FloatTensor:
509
- """
510
- Compute logits with optional vocabulary mapping.
511
-
512
- :param hidden_states: Hidden states from the model
513
- :param map_to_target_vocab: Whether to map draft logits to target vocabulary
514
- :return: Logits tensor
515
- """
516
- logits = self.lm_head(hidden_states)
517
-
518
- if not map_to_target_vocab:
519
- return logits
520
-
521
- batch_size, seq_length, _ = logits.shape
522
-
523
- draft_indices = torch.arange(self.draft_vocab_size, device=logits.device)
524
-
525
- target_indices = draft_indices + self.d2t
526
-
527
- mapped_logits = logits.new_full(
528
- (batch_size, seq_length, self.target_vocab_size), float("-inf")
529
- )
530
-
531
- mapped_logits[:, :, target_indices] = logits
532
-
533
- return mapped_logits
534
-
535
- def map_draft_to_target_tokens(
536
- self, draft_tokens: torch.LongTensor
537
- ) -> torch.LongTensor:
538
- """
539
- Map draft token IDs to target token IDs.
540
-
541
- :param draft_tokens: Draft vocabulary token IDs
542
- :return: Target vocabulary token IDs
543
- """
544
- return draft_tokens + self.d2t[draft_tokens] # type: ignore[return-value]
545
-
546
- def check_target_token_availability(
547
- self, target_tokens: torch.LongTensor
548
- ) -> torch.BoolTensor:
549
- """
550
- Check if target tokens have draft equivalents.
551
-
552
- :param target_tokens: Target vocabulary token IDs
553
- :return: Boolean mask indicating availability in draft vocabulary
554
- """
555
- return self.t2d[target_tokens] # type: ignore[return-value]
556
-
557
- def tie_weights(self):
558
- """
559
- Override tie_weights to prevent vocabulary corruption in transformers 4.54.1+
560
-
561
- Eagle3 intentionally uses different vocabulary sizes:
562
- - Input embeddings (embed_tokens): 128256 (full vocabulary)
563
- - Output embeddings (lm_head): 32000 (draft vocabulary)
564
-
565
- The default tie_weights() tries to make them identical, breaking Eagle3.
566
- This override preserves the intentional vocabulary size difference.
567
- """
568
- # Don't call super().tie_weights() - this prevents vocabulary corruption
569
- # that occurs when _tie_or_clone_weights replaces lm_head.weight with
570
- # embed_tokens.weight
 
13
  """
14
 
15
  import os
16
+ from typing import Any, ClassVar, Literal
17
 
18
  import torch
19
  from pydantic import Field, field_serializer, field_validator
20
  from torch import nn
21
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
 
22
  from transformers.models.llama.configuration_llama import LlamaConfig
23
  from transformers.models.llama.modeling_llama import (
24
  LlamaMLP,
 
71
  description="Apply hidden_norm before storing residual",
72
  )
73
 
74
+ target_hidden_size: int | None = Field(
75
  default=None,
76
  description="Hidden size of the target model (if different from draft model)",
77
  )
78
 
79
+ eagle_aux_hidden_state_layer_ids: list[int] | None = Field(
80
+ default=None,
81
+ description="Layer IDs of the Eagle auxiliary hidden state layers",
82
+ )
83
+
84
  @property
85
  def target_vocab_size(self) -> int:
86
  """Get target vocabulary size from transformer config."""
 
98
  if isinstance(value, dict):
99
  config_class: type[PretrainedConfig] = LlamaConfig
100
  if "model_type" in value:
 
 
101
  config_class = AutoConfig.for_model(
102
  model_type=value["model_type"]
103
  ).__class__
 
145
  def forward(
146
  self,
147
  hidden_states: torch.Tensor,
148
+ attention_mask: torch.Tensor | None = None,
149
+ position_ids: torch.LongTensor | None = None,
150
+ past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
151
  output_attentions: bool = False,
152
  use_cache: bool = False,
153
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
154
  **kwargs, # noqa: ARG002
155
  ) -> tuple:
156
  """
 
255
  def forward(
256
  self,
257
  hidden_states: torch.Tensor,
258
+ attention_mask: torch.Tensor | None = None,
259
+ position_ids: torch.LongTensor | None = None,
260
+ past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
261
+ output_attentions: bool | None = False,
262
+ use_cache: bool | None = False,
263
+ cache_position: torch.LongTensor | None = None, # noqa: ARG002
264
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
265
  **kwargs, # noqa: ARG002
266
  ) -> tuple:
267
  """
 
332
  def __init__(
333
  self,
334
  config: Eagle3SpeculatorConfig,
335
+ verifier: str | os.PathLike | PreTrainedModel | None = None,
336
+ verifier_attachment_mode: Literal["detached", "full", "train_only"]
337
+ | None = None,
338
+ reduce_vocab_size: bool = True,
339
+ has_drafter_embedding: bool = True,
340
  ):
341
  """
342
  Initialize Eagle3 speculator.
 
344
  :param config: Eagle3SpeculatorConfig instance
345
  :param verifier: Optional verifier model
346
  :param verifier_attachment_mode: How to attach the verifier
347
+ :param reduce_vocab_size: Whether to reduce vocabulary size with mapping
348
+ :param has_drafter_embedding: Whether drafter embedding weights are provided
349
  """
350
  if not isinstance(config, Eagle3SpeculatorConfig):
351
  raise ValueError(
 
371
  verifier_attachment_mode=verifier_attachment_mode,
372
  )
373
 
374
+ if has_drafter_embedding:
375
+ self.embed_tokens = nn.Embedding(
376
+ self.target_vocab_size,
377
+ self.hidden_size,
378
+ padding_idx=config.transformer_layer_config.pad_token_id
379
+ if hasattr(config.transformer_layer_config, "pad_token_id")
380
+ else None,
381
+ )
382
 
383
  self.fc = nn.Linear(
384
  3 * self.target_hidden_size, # Use target model's hidden size
 
406
  self.draft_vocab_size,
407
  bias=False,
408
  )
409
+ if reduce_vocab_size:
410
+ self.register_buffer( # type: ignore[attr-defined]
411
+ "d2t",
412
+ torch.zeros(self.draft_vocab_size, dtype=torch.long),
413
+ )
414
+ self.register_buffer( # type: ignore[attr-defined]
415
+ "t2d",
416
+ torch.zeros(self.target_vocab_size, dtype=torch.bool),
417
+ )
418
 
419
+ # Type hints for buffers
420
+ self.d2t: torch.Tensor
421
+ self.t2d: torch.Tensor
422
+ self.post_init() # type: ignore[attr-defined]
 
 
 
 
423
 
424
+ def tie_weights(self):
425
+ """
426
+ Override tie_weights to prevent vocabulary corruption in transformers 4.54.1+
427
 
428
+ Eagle3 intentionally uses different vocabulary sizes:
429
+ - Input embeddings (embed_tokens): 128256 (full vocabulary)
430
+ - Output embeddings (lm_head): 32000 (draft vocabulary)
431
+
432
+ The default tie_weights() tries to make them identical, breaking Eagle3.
433
+ This override preserves the intentional vocabulary size difference.
434
+ """
435
+ # Don't call super().tie_weights() - this prevents vocabulary corruption
436
+ # that occurs when _tie_or_clone_weights replaces lm_head.weight with
437
+ # embed_tokens.weight
438
 
439
  def forward(
440
  self,
441
  input_ids: torch.LongTensor,
442
  hidden_states: torch.FloatTensor,
443
+ attention_mask: torch.Tensor | None = None,
444
+ position_ids: torch.LongTensor | None = None,
445
+ past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
446
+ use_cache: bool | None = None,
447
+ output_attentions: bool | None = None,
448
+ output_hidden_states: bool | None = None, # noqa: ARG002
449
+ return_dict: bool | None = None,
450
+ ) -> torch.FloatTensor:
451
  """
452
  Forward pass for EAGLE-3 speculation.
453
 
 
463
  :param return_dict: Return dict output
464
  :return: Model outputs with draft vocabulary logits
465
  """
466
+ raise NotImplementedError("Eagle3Speculator.forward is not implemented yet.")