RelaxingSnorlax committed
Commit 2bcfb08 · verified · Parent: aa996f9

Upload folder using huggingface_hub

Files changed (2)
  1. config.json +1 -0
  2. eagle3.py +14 -2
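
In short, the commit adds an optional target_hidden_size field to Eagle3SpeculatorConfig and uses it to size the fc fusion layer, so the Eagle3 draft model can consume hidden states from a verifier (target) model whose hidden size differs from the draft's own.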
config.json CHANGED
@@ -28,6 +28,7 @@
   },
   "speculators_model_type": "eagle3",
   "speculators_version": "0.1.0.dev13",
+  "target_hidden_size": null,
   "torch_dtype": "float32",
   "transformer_layer_config": {
     "attention_bias": false,
eagle3.py CHANGED
@@ -73,6 +73,11 @@ class Eagle3SpeculatorConfig(SpeculatorModelConfig):
         description="Apply hidden_norm before storing residual",
     )
 
+    target_hidden_size: Optional[int] = Field(
+        default=None,
+        description="Hidden size of the target model (if different from draft model)",
+    )
+
     @property
     def target_vocab_size(self) -> int:
         """Get target vocabulary size from transformer config."""
@@ -349,6 +354,13 @@ class Eagle3Speculator(SpeculatorModel):
         self.draft_vocab_size = config.draft_vocab_size
         self.target_vocab_size = config.target_vocab_size
 
+        # Use target_hidden_size if specified, otherwise use draft model's hidden_size
+        self.target_hidden_size = (
+            config.target_hidden_size
+            if config.target_hidden_size is not None
+            else self.hidden_size
+        )
+
         super().__init__(
             config=config,
             verifier=verifier,
@@ -364,7 +376,7 @@ class Eagle3Speculator(SpeculatorModel):
         )
 
         self.fc = nn.Linear(
-            3 * self.hidden_size,
+            3 * self.target_hidden_size,  # Use target model's hidden size
             self.hidden_size,
             bias=False,
         )
@@ -422,7 +434,7 @@ class Eagle3Speculator(SpeculatorModel):
 
         :param input_ids: Input token IDs from draft vocabulary
         :param hidden_states: Concatenated hidden states from 3 verifier layers
-            [B, L, 3*H]
+            [B, L, 3*target_H] where target_H is the target model's hidden size
         :param attention_mask: Optional attention mask
         :param position_ids: Optional position IDs
         :param past_key_values: Optional cached key-values
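
The substantive fix is the input width of fc: the hidden states fed to the speculator are concatenated from three verifier layers, so their width is 3 * target_H, which only equals 3 * H when draft and target share a hidden size. A standalone shape check with hypothetical sizes (B, L, draft_H, and target_H below are made-up illustration values, not taken from any checkpoint):

    import torch
    from torch import nn

    B, L = 2, 16       # batch size and sequence length (illustrative)
    draft_H = 4096     # draft/speculator hidden size: self.hidden_size
    target_H = 8192    # verifier hidden size: config.target_hidden_size

    fc = nn.Linear(3 * target_H, draft_H, bias=False)

    # Hidden states concatenated from 3 verifier layers: [B, L, 3 * target_H]
    hidden_states = torch.randn(B, L, 3 * target_H)
    fused = fc(hidden_states)
    assert fused.shape == (B, L, draft_H)

    # The pre-commit definition, nn.Linear(3 * draft_H, draft_H), would raise
    # a matmul shape error on this input whenever target_H != draft_H.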