maxholsman committed on
Commit
2cac9f5
·
verified ·
1 Parent(s): 83f939c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +152 -68
custom_generate/generate.py CHANGED
@@ -62,6 +62,10 @@ class GenerateDecoderOnlyOutput(ModelOutput):
62
  attentions: tuple[tuple[torch.FloatTensor]] | None = None
63
  hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
64
  past_key_values: Cache | None = None
 
 
 
 
65
 
66
 
67
  @dataclass
@@ -77,6 +81,10 @@ class GenerateEncoderDecoderOutput(ModelOutput):
77
  cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
78
  decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
79
  past_key_values: Cache | None = None
 
 
 
 
80
 
81
 
82
  def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
@@ -115,9 +123,16 @@ class RawLogitsCandidateGenerator(AssistedCandidateGenerator):
115
  """Initialize the custom candidate generator."""
116
  super().__init__(*args, **kwargs)
117
  # Initialize probs list if sklearn is available and confidence threshold is enabled
 
 
 
 
 
 
118
  if (
119
  is_sklearn_available()
120
- and self.assistant_generation_config.assistant_confidence_threshold
 
121
  ):
122
  if not hasattr(self, 'probs'):
123
  self.probs = []
@@ -149,9 +164,15 @@ class RawLogitsCandidateGenerator(AssistedCandidateGenerator):
149
  self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
150
 
151
  # Handle sklearn confidence threshold tracking (if enabled)
 
 
 
 
 
152
  if (
153
  is_sklearn_available()
154
- and self.assistant_generation_config.assistant_confidence_threshold
 
155
  and type(self) is RawLogitsCandidateGenerator
156
  ):
157
  scores_tensor = torch.cat(assistant_output.scores, dim=0)
@@ -181,7 +202,7 @@ def _speculative_sampling(
181
  is_done_candidate,
182
  candidate_logits_raw,
183
  fsd_threshold: float = 0.0,
184
- fsd_div_type: str = "kl"
185
  ):
186
  """
187
  Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
@@ -210,21 +231,24 @@ def _speculative_sampling(
210
  ).sum(dim=-1)
211
  elif fsd_div_type == "js":
212
 
213
- m = 0.5 * (cand_probs + target_probs[:, :-1, :]) # Mixture distribution
 
 
 
214
 
215
- # Compute KL(P || M) and KL(Q || M)
216
- kl_pm = kl_div(
217
- m.log().clamp(min=-1e10), # log-probabilities of mixture
218
- cand_probs, # probabilities of candidate
219
- reduction='none'
220
- )
221
- kl_qm = kl_div(
222
- m.log().clamp(min=-1e10), # log-probabilities of mixture
223
- target_probs[:, :-1, :], # probabilities of target
224
- reduction='none'
225
- )
226
 
227
- divs = 0.5 * (kl_pm + kl_qm).sum(dim=-1)
228
 
229
  elif fsd_div_type == "draft_tokens":
230
  draft_token_ids = new_candidate_input_ids # shape: (batch, candidate_length)
@@ -287,7 +311,8 @@ def _assisted_decoding(
287
  assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
288
  tokenizer: Optional["PreTrainedTokenizerBase"] = None,
289
  fsd_threshold: float = 0.0,
290
- fsd_div_type: str = "kl",
 
291
  **model_kwargs,
292
  ) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
293
  r"""
@@ -328,6 +353,14 @@ def _assisted_decoding(
328
  output_scores = generation_config.output_scores
329
  output_logits = generation_config.output_logits
330
  return_dict_in_generate = generation_config.return_dict_in_generate
 
 
 
 
 
 
 
 
331
 
332
  # init attention / hidden states / scores tuples
333
  scores = () if (return_dict_in_generate and output_scores) else None
@@ -417,6 +450,10 @@ def _assisted_decoding(
417
  fsd_threshold=fsd_threshold,
418
  fsd_div_type=fsd_div_type,
419
  )
 
 
 
 
420
 
421
  # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
422
  # original model logits with the candidate tokens. We can keep the candidate tokens until the first
@@ -435,6 +472,11 @@ def _assisted_decoding(
435
  if is_done_candidate and n_matches == candidate_length:
436
  n_matches -= 1
437
  valid_tokens = selected_tokens[:, : n_matches + 1]
 
 
 
 
 
438
 
439
  # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
440
  # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
@@ -518,32 +560,69 @@ def _assisted_decoding(
518
  candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
519
  candidate_generator.num_assistant_tokens
520
  )
 
 
 
 
 
 
 
 
521
  if return_dict_in_generate:
522
  cache = None
523
  if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
524
  cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
525
  cache = model_kwargs[cache_key]
 
526
  if model.config.is_encoder_decoder:
527
- return GenerateEncoderDecoderOutput(
528
- sequences=input_ids,
529
- scores=scores,
530
- logits=raw_logits,
531
- encoder_attentions=encoder_attentions,
532
- encoder_hidden_states=encoder_hidden_states,
533
- decoder_attentions=decoder_attentions,
534
- cross_attentions=cross_attentions,
535
- decoder_hidden_states=decoder_hidden_states,
536
- past_key_values=cache,
537
- )
 
538
  else:
539
- return GenerateDecoderOnlyOutput(
540
- sequences=input_ids,
541
- scores=scores,
542
- logits=raw_logits,
543
- attentions=decoder_attentions,
544
- hidden_states=decoder_hidden_states,
545
- past_key_values=cache,
546
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  else:
548
  return input_ids
549
 
@@ -570,8 +649,12 @@ def generate(
570
  """
571
  # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
572
  # Extract custom parameters before validation (they're not standard generation config params)
 
 
 
573
  fsd_threshold = kwargs.pop("fsd_threshold", 0.0)
574
- fsd_div_type = kwargs.pop("fsd_div_type", "kl")
 
575
 
576
  generation_mode_kwargs = model._extract_generation_mode_kwargs(
577
  None, # custom_generate
@@ -583,6 +666,7 @@ def generate(
583
  # Add custom FSD parameters to generation_mode_kwargs so they're passed to _assisted_decoding
584
  generation_mode_kwargs["fsd_threshold"] = fsd_threshold
585
  generation_mode_kwargs["fsd_div_type"] = fsd_div_type
 
586
 
587
  # Check length values before updating the config with defaults
588
  has_default_max_length = kwargs.get("max_length") is None and (
@@ -830,47 +914,47 @@ def generate(
830
  # new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
831
  # correction_term = 0
832
 
833
- # if div_type != 'sd':
834
 
835
- # if div_type == 'kl_div_processed' or div_type == 'js_div_processed' or div_type == 'tv_div_processed':
836
- # epsilon = 1e-10
837
- # q = candidate_logits.softmax(dim=-1)
838
- # p = new_logits[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
839
 
840
- # q_nonzero = (p > 0).int()
841
- # p_nonzero = (q > 0).int()
842
- # both_nonzero = (q_nonzero & p_nonzero).int()
843
 
844
- # # print(f"nonzero q: {q_nonzero.sum(dim=-1)}")
845
- # # print(f"nonzero p: {p_nonzero.sum(dim=-1)}")
846
- # # print(f"both nonzero: {both_nonzero.sum(dim=-1)}")
847
 
848
- # q = q + epsilon
849
- # p = p + epsilon
850
 
851
- # p = p / p.sum(dim=-1, keepdim=True)
852
- # q = q / q.sum(dim=-1, keepdim=True)
853
 
854
 
855
- # else:
856
- # q = candidate_logits_unprocessed.softmax(dim=-1)
857
- # p = new_logits_unprocessed[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
858
 
859
- # if len(div_logits_processor) > 0:
860
- # epsilon = 1e-10
861
- # q = q + epsilon
862
- # p = p + epsilon
863
 
864
- # p = p / p.sum(dim=-1, keepdim=True)
865
- # q = q / q.sum(dim=-1, keepdim=True)
866
 
867
- # if div_type == 'kl_div' or div_type == 'kl_div_processed':
868
- # divs = torch.nn.functional.kl_div(torch.log(p), q, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
869
- # elif div_type == 'kl_div_reversed' or div_type == 'kl_div_reversed_processed':
870
- # divs = torch.nn.functional.kl_div(torch.log(q), p, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
871
- # elif div_type == 'js_div' or div_type == 'js_div_processed':
872
- # m = 0.5 * (p + q) # Midpoint distribution
873
- # divs = (0.5 * torch.nn.functional.kl_div(torch.log(p), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(q), m, reduction='none')).sum(dim=-1)
874
  # elif div_type == 'tv_div' or div_type == 'tv_div_processed':
875
  # divs = 0.5 * torch.abs(p - q).sum(dim=-1)
876
 
 
62
  attentions: tuple[tuple[torch.FloatTensor]] | None = None
63
  hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
64
  past_key_values: Cache | None = None
65
+ # Draft token acceptance tracking fields (optional for backward compatibility)
66
+ draft_token_acceptance_rate: float | None = None
67
+ total_draft_tokens: int | None = None
68
+ total_accepted_tokens: int | None = None
69
 
70
 
71
  @dataclass
 
81
  cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
82
  decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
83
  past_key_values: Cache | None = None
84
+ # Draft token acceptance tracking fields (optional for backward compatibility)
85
+ draft_token_acceptance_rate: float | None = None
86
+ total_draft_tokens: int | None = None
87
+ total_accepted_tokens: int | None = None
88
 
89
 
90
  def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
 
123
  """Initialize the custom candidate generator."""
124
  super().__init__(*args, **kwargs)
125
  # Initialize probs list if sklearn is available and confidence threshold is enabled
126
+ # Handle both transformers versions (with and without assistant_generation_config)
127
+ assistant_config = getattr(self, 'assistant_generation_config', None)
128
+ if assistant_config is None:
129
+ # Fallback for transformers versions that don't set assistant_generation_config
130
+ assistant_config = self.assistant_model.generation_config
131
+
132
  if (
133
  is_sklearn_available()
134
+ and hasattr(assistant_config, 'assistant_confidence_threshold')
135
+ and assistant_config.assistant_confidence_threshold
136
  ):
137
  if not hasattr(self, 'probs'):
138
  self.probs = []
 
164
  self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
165
 
166
  # Handle sklearn confidence threshold tracking (if enabled)
167
+ # Handle both transformers versions (with and without assistant_generation_config)
168
+ assistant_config = getattr(self, 'assistant_generation_config', None)
169
+ if assistant_config is None:
170
+ assistant_config = self.assistant_model.generation_config
171
+
172
  if (
173
  is_sklearn_available()
174
+ and hasattr(assistant_config, 'assistant_confidence_threshold')
175
+ and assistant_config.assistant_confidence_threshold
176
  and type(self) is RawLogitsCandidateGenerator
177
  ):
178
  scores_tensor = torch.cat(assistant_output.scores, dim=0)
 
202
  is_done_candidate,
203
  candidate_logits_raw,
204
  fsd_threshold: float = 0.0,
205
+ fsd_div_type: str = "js"
206
  ):
207
  """
208
  Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
 
231
  ).sum(dim=-1)
232
  elif fsd_div_type == "js":
233
 
234
+ m = 0.5 * (cand_probs + target_probs[:, :-1, :]) # Midpoint distribution
235
+ divs = (0.5 * torch.nn.functional.kl_div(torch.log(cand_probs), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(target_probs[:, :-1, :]), m, reduction='none')).sum(dim=-1)
236
+
237
+ # m = 0.5 * (cand_probs + target_probs[:, :-1, :]) # Mixture distribution
238
 
239
+ # # Compute KL(P || M) and KL(Q || M)
240
+ # kl_pm = kl_div(
241
+ # m.log().clamp(min=-1e10), # log-probabilities of mixture
242
+ # cand_probs, # probabilities of candidate
243
+ # reduction='none'
244
+ # )
245
+ # kl_qm = kl_div(
246
+ # m.log().clamp(min=-1e10), # log-probabilities of mixture
247
+ # target_probs[:, :-1, :], # probabilities of target
248
+ # reduction='none'
249
+ # )
250
 
251
+ # divs = 0.5 * (kl_pm + kl_qm).sum(dim=-1)
252
 
253
  elif fsd_div_type == "draft_tokens":
254
  draft_token_ids = new_candidate_input_ids # shape: (batch, candidate_length)
 
311
  assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
312
  tokenizer: Optional["PreTrainedTokenizerBase"] = None,
313
  fsd_threshold: float = 0.0,
314
+ fsd_div_type: str = "js",
315
+ track_acceptance_metrics: bool = False,
316
  **model_kwargs,
317
  ) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
318
  r"""
 
353
  output_scores = generation_config.output_scores
354
  output_logits = generation_config.output_logits
355
  return_dict_in_generate = generation_config.return_dict_in_generate
356
+
357
+ # Track draft token acceptance statistics (only if enabled)
358
+ if track_acceptance_metrics:
359
+ total_draft_tokens = 0
360
+ total_accepted_tokens = 0
361
+ else:
362
+ total_draft_tokens = None
363
+ total_accepted_tokens = None
364
 
365
  # init attention / hidden states / scores tuples
366
  scores = () if (return_dict_in_generate and output_scores) else None
 
450
  fsd_threshold=fsd_threshold,
451
  fsd_div_type=fsd_div_type,
452
  )
453
+ # Track acceptance statistics (only if we have draft tokens and tracking is enabled)
454
+ if track_acceptance_metrics and candidate_length > 0:
455
+ total_draft_tokens += candidate_length
456
+ total_accepted_tokens += n_matches
457
 
458
  # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
459
  # original model logits with the candidate tokens. We can keep the candidate tokens until the first
 
472
  if is_done_candidate and n_matches == candidate_length:
473
  n_matches -= 1
474
  valid_tokens = selected_tokens[:, : n_matches + 1]
475
+
476
+ # Track acceptance statistics (for non-sampling case, only if we have draft tokens and tracking is enabled)
477
+ if track_acceptance_metrics and candidate_length > 0:
478
+ total_draft_tokens += candidate_length
479
+ total_accepted_tokens += n_matches
480
 
481
  # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
482
  # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
 
560
  candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
561
  candidate_generator.num_assistant_tokens
562
  )
563
+ # Calculate draft token acceptance rate (only if tracking is enabled)
564
+ if track_acceptance_metrics:
565
+ acceptance_rate = total_accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0.0
566
+ else:
567
+ acceptance_rate = None
568
+ total_draft_tokens = None
569
+ total_accepted_tokens = None
570
+
571
  if return_dict_in_generate:
572
  cache = None
573
  if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
574
  cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
575
  cache = model_kwargs[cache_key]
576
+ # Build base output dict
577
  if model.config.is_encoder_decoder:
578
+ base_dict = {
579
+ "sequences": input_ids,
580
+ "scores": scores,
581
+ "logits": raw_logits,
582
+ "encoder_attentions": encoder_attentions,
583
+ "encoder_hidden_states": encoder_hidden_states,
584
+ "decoder_attentions": decoder_attentions,
585
+ "cross_attentions": cross_attentions,
586
+ "decoder_hidden_states": decoder_hidden_states,
587
+ "past_key_values": cache,
588
+ }
589
+ output_class = GenerateEncoderDecoderOutput
590
  else:
591
+ base_dict = {
592
+ "sequences": input_ids,
593
+ "scores": scores,
594
+ "logits": raw_logits,
595
+ "attentions": decoder_attentions,
596
+ "hidden_states": decoder_hidden_states,
597
+ "past_key_values": cache,
598
+ }
599
+ output_class = GenerateDecoderOnlyOutput
600
+
601
+ # Try to create output with acceptance rate fields (only if tracking is enabled)
602
+ # If the Hub version doesn't support these fields, create without them
603
+ if track_acceptance_metrics:
604
+ try:
605
+ return output_class(
606
+ **base_dict,
607
+ draft_token_acceptance_rate=acceptance_rate,
608
+ total_draft_tokens=total_draft_tokens,
609
+ total_accepted_tokens=total_accepted_tokens,
610
+ )
611
+ except TypeError:
612
+ # Hub version doesn't support these fields, create without them
613
+ output = output_class(**base_dict)
614
+ # Try to set the fields as attributes (ModelOutput should allow this)
615
+ try:
616
+ output.draft_token_acceptance_rate = acceptance_rate
617
+ output.total_draft_tokens = total_draft_tokens
618
+ output.total_accepted_tokens = total_accepted_tokens
619
+ except Exception:
620
+ # If setting attributes fails, just return without them
621
+ pass
622
+ return output
623
+ else:
624
+ # Tracking disabled, return without metrics
625
+ return output_class(**base_dict)
626
  else:
627
  return input_ids
628
 
 
649
  """
650
  # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
651
  # Extract custom parameters before validation (they're not standard generation config params)
652
+ # These are used for loading the custom generate function, not for the generation process itself
653
+ custom_generate = kwargs.pop("custom_generate", None)
654
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
655
  fsd_threshold = kwargs.pop("fsd_threshold", 0.0)
656
+ fsd_div_type = kwargs.pop("fsd_div_type", "js")
657
+ track_acceptance_metrics = kwargs.pop("track_acceptance_metrics", False)
658
 
659
  generation_mode_kwargs = model._extract_generation_mode_kwargs(
660
  None, # custom_generate
 
666
  # Add custom FSD parameters to generation_mode_kwargs so they're passed to _assisted_decoding
667
  generation_mode_kwargs["fsd_threshold"] = fsd_threshold
668
  generation_mode_kwargs["fsd_div_type"] = fsd_div_type
669
+ generation_mode_kwargs["track_acceptance_metrics"] = track_acceptance_metrics
670
 
671
  # Check length values before updating the config with defaults
672
  has_default_max_length = kwargs.get("max_length") is None and (
 
914
  # new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
915
  # correction_term = 0
916
 
917
+ # if div_type != 'sd':
918
 
919
+ # if div_type == 'kl_div_processed' or div_type == 'js_div_processed' or div_type == 'tv_div_processed':
920
+ # epsilon = 1e-10
921
+ # q = candidate_logits.softmax(dim=-1)
922
+ # p = new_logits[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
923
 
924
+ # q_nonzero = (p > 0).int()
925
+ # p_nonzero = (q > 0).int()
926
+ # both_nonzero = (q_nonzero & p_nonzero).int()
927
 
928
+ # # print(f"nonzero q: {q_nonzero.sum(dim=-1)}")
929
+ # # print(f"nonzero p: {p_nonzero.sum(dim=-1)}")
930
+ # # print(f"both nonzero: {both_nonzero.sum(dim=-1)}")
931
 
932
+ # q = q + epsilon
933
+ # p = p + epsilon
934
 
935
+ # p = p / p.sum(dim=-1, keepdim=True)
936
+ # q = q / q.sum(dim=-1, keepdim=True)
937
 
938
 
939
+ # else:
940
+ # q = candidate_logits_unprocessed.softmax(dim=-1)
941
+ # p = new_logits_unprocessed[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
942
 
943
+ # if len(div_logits_processor) > 0:
944
+ # epsilon = 1e-10
945
+ # q = q + epsilon
946
+ # p = p + epsilon
947
 
948
+ # p = p / p.sum(dim=-1, keepdim=True)
949
+ # q = q / q.sum(dim=-1, keepdim=True)
950
 
951
+ # if div_type == 'kl_div' or div_type == 'kl_div_processed':
952
+ # divs = torch.nn.functional.kl_div(torch.log(p), q, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
953
+ # elif div_type == 'kl_div_reversed' or div_type == 'kl_div_reversed_processed':
954
+ # divs = torch.nn.functional.kl_div(torch.log(q), p, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
955
+ # elif div_type == 'js_div' or div_type == 'js_div_processed':
956
+ # m = 0.5 * (p + q) # Midpoint distribution
957
+ # divs = (0.5 * torch.nn.functional.kl_div(torch.log(p), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(q), m, reduction='none')).sum(dim=-1)
958
  # elif div_type == 'tv_div' or div_type == 'tv_div_processed':
959
  # divs = 0.5 * torch.abs(p - q).sum(dim=-1)
960