Upload folder using huggingface_hub
Browse files- custom_generate/generate.py +152 -68
custom_generate/generate.py
CHANGED
|
@@ -62,6 +62,10 @@ class GenerateDecoderOnlyOutput(ModelOutput):
|
|
| 62 |
attentions: tuple[tuple[torch.FloatTensor]] | None = None
|
| 63 |
hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
|
| 64 |
past_key_values: Cache | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
@dataclass
|
|
@@ -77,6 +81,10 @@ class GenerateEncoderDecoderOutput(ModelOutput):
|
|
| 77 |
cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
|
| 78 |
decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
|
| 79 |
past_key_values: Cache | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
|
|
@@ -115,9 +123,16 @@ class RawLogitsCandidateGenerator(AssistedCandidateGenerator):
|
|
| 115 |
"""Initialize the custom candidate generator."""
|
| 116 |
super().__init__(*args, **kwargs)
|
| 117 |
# Initialize probs list if sklearn is available and confidence threshold is enabled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
if (
|
| 119 |
is_sklearn_available()
|
| 120 |
-
and
|
|
|
|
| 121 |
):
|
| 122 |
if not hasattr(self, 'probs'):
|
| 123 |
self.probs = []
|
|
@@ -149,9 +164,15 @@ class RawLogitsCandidateGenerator(AssistedCandidateGenerator):
|
|
| 149 |
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
| 150 |
|
| 151 |
# Handle sklearn confidence threshold tracking (if enabled)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
if (
|
| 153 |
is_sklearn_available()
|
| 154 |
-
and
|
|
|
|
| 155 |
and type(self) is RawLogitsCandidateGenerator
|
| 156 |
):
|
| 157 |
scores_tensor = torch.cat(assistant_output.scores, dim=0)
|
|
@@ -181,7 +202,7 @@ def _speculative_sampling(
|
|
| 181 |
is_done_candidate,
|
| 182 |
candidate_logits_raw,
|
| 183 |
fsd_threshold: float = 0.0,
|
| 184 |
-
fsd_div_type: str = "
|
| 185 |
):
|
| 186 |
"""
|
| 187 |
Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
|
|
@@ -210,21 +231,24 @@ def _speculative_sampling(
|
|
| 210 |
).sum(dim=-1)
|
| 211 |
elif fsd_div_type == "js":
|
| 212 |
|
| 213 |
-
m = 0.5 * (cand_probs + target_probs[:, :-1, :]) #
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
-
# Compute KL(P || M) and KL(Q || M)
|
| 216 |
-
kl_pm = kl_div(
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
)
|
| 221 |
-
kl_qm = kl_div(
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
)
|
| 226 |
|
| 227 |
-
divs = 0.5 * (kl_pm + kl_qm).sum(dim=-1)
|
| 228 |
|
| 229 |
elif fsd_div_type == "draft_tokens":
|
| 230 |
draft_token_ids = new_candidate_input_ids # shape: (batch, candidate_length)
|
|
@@ -287,7 +311,8 @@ def _assisted_decoding(
|
|
| 287 |
assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
| 288 |
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
| 289 |
fsd_threshold: float = 0.0,
|
| 290 |
-
fsd_div_type: str = "
|
|
|
|
| 291 |
**model_kwargs,
|
| 292 |
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
|
| 293 |
r"""
|
|
@@ -328,6 +353,14 @@ def _assisted_decoding(
|
|
| 328 |
output_scores = generation_config.output_scores
|
| 329 |
output_logits = generation_config.output_logits
|
| 330 |
return_dict_in_generate = generation_config.return_dict_in_generate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
# init attention / hidden states / scores tuples
|
| 333 |
scores = () if (return_dict_in_generate and output_scores) else None
|
|
@@ -417,6 +450,10 @@ def _assisted_decoding(
|
|
| 417 |
fsd_threshold=fsd_threshold,
|
| 418 |
fsd_div_type=fsd_div_type,
|
| 419 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
|
| 422 |
# original model logits with the candidate tokens. We can keep the candidate tokens until the first
|
|
@@ -435,6 +472,11 @@ def _assisted_decoding(
|
|
| 435 |
if is_done_candidate and n_matches == candidate_length:
|
| 436 |
n_matches -= 1
|
| 437 |
valid_tokens = selected_tokens[:, : n_matches + 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
| 440 |
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
|
|
@@ -518,32 +560,69 @@ def _assisted_decoding(
|
|
| 518 |
candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
|
| 519 |
candidate_generator.num_assistant_tokens
|
| 520 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
if return_dict_in_generate:
|
| 522 |
cache = None
|
| 523 |
if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
|
| 524 |
cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
|
| 525 |
cache = model_kwargs[cache_key]
|
|
|
|
| 526 |
if model.config.is_encoder_decoder:
|
| 527 |
-
|
| 528 |
-
sequences
|
| 529 |
-
scores
|
| 530 |
-
logits
|
| 531 |
-
encoder_attentions
|
| 532 |
-
encoder_hidden_states
|
| 533 |
-
decoder_attentions
|
| 534 |
-
cross_attentions
|
| 535 |
-
decoder_hidden_states
|
| 536 |
-
past_key_values
|
| 537 |
-
|
|
|
|
| 538 |
else:
|
| 539 |
-
|
| 540 |
-
sequences
|
| 541 |
-
scores
|
| 542 |
-
logits
|
| 543 |
-
attentions
|
| 544 |
-
hidden_states
|
| 545 |
-
past_key_values
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
else:
|
| 548 |
return input_ids
|
| 549 |
|
|
@@ -570,8 +649,12 @@ def generate(
|
|
| 570 |
"""
|
| 571 |
# 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
|
| 572 |
# Extract custom parameters before validation (they're not standard generation config params)
|
|
|
|
|
|
|
|
|
|
| 573 |
fsd_threshold = kwargs.pop("fsd_threshold", 0.0)
|
| 574 |
-
fsd_div_type = kwargs.pop("fsd_div_type", "
|
|
|
|
| 575 |
|
| 576 |
generation_mode_kwargs = model._extract_generation_mode_kwargs(
|
| 577 |
None, # custom_generate
|
|
@@ -583,6 +666,7 @@ def generate(
|
|
| 583 |
# Add custom FSD parameters to generation_mode_kwargs so they're passed to _assisted_decoding
|
| 584 |
generation_mode_kwargs["fsd_threshold"] = fsd_threshold
|
| 585 |
generation_mode_kwargs["fsd_div_type"] = fsd_div_type
|
|
|
|
| 586 |
|
| 587 |
# Check length values before updating the config with defaults
|
| 588 |
has_default_max_length = kwargs.get("max_length") is None and (
|
|
@@ -830,47 +914,47 @@ def generate(
|
|
| 830 |
# new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
|
| 831 |
# correction_term = 0
|
| 832 |
|
| 833 |
-
#
|
| 834 |
|
| 835 |
-
#
|
| 836 |
-
#
|
| 837 |
-
#
|
| 838 |
-
#
|
| 839 |
|
| 840 |
-
#
|
| 841 |
-
#
|
| 842 |
-
#
|
| 843 |
|
| 844 |
-
#
|
| 845 |
-
#
|
| 846 |
-
#
|
| 847 |
|
| 848 |
-
#
|
| 849 |
-
#
|
| 850 |
|
| 851 |
-
#
|
| 852 |
-
#
|
| 853 |
|
| 854 |
|
| 855 |
-
#
|
| 856 |
-
#
|
| 857 |
-
#
|
| 858 |
|
| 859 |
-
#
|
| 860 |
-
#
|
| 861 |
-
#
|
| 862 |
-
#
|
| 863 |
|
| 864 |
-
#
|
| 865 |
-
#
|
| 866 |
|
| 867 |
-
#
|
| 868 |
-
#
|
| 869 |
-
#
|
| 870 |
-
#
|
| 871 |
-
#
|
| 872 |
-
#
|
| 873 |
-
#
|
| 874 |
# elif div_type == 'tv_div' or div_type == 'tv_div_processed':
|
| 875 |
# divs = 0.5 * torch.abs(p - q).sum(dim=-1)
|
| 876 |
|
|
|
|
| 62 |
attentions: tuple[tuple[torch.FloatTensor]] | None = None
|
| 63 |
hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
|
| 64 |
past_key_values: Cache | None = None
|
| 65 |
+
# Draft token acceptance tracking fields (optional for backward compatibility)
|
| 66 |
+
draft_token_acceptance_rate: float | None = None
|
| 67 |
+
total_draft_tokens: int | None = None
|
| 68 |
+
total_accepted_tokens: int | None = None
|
| 69 |
|
| 70 |
|
| 71 |
@dataclass
|
|
|
|
| 81 |
cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
|
| 82 |
decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
|
| 83 |
past_key_values: Cache | None = None
|
| 84 |
+
# Draft token acceptance tracking fields (optional for backward compatibility)
|
| 85 |
+
draft_token_acceptance_rate: float | None = None
|
| 86 |
+
total_draft_tokens: int | None = None
|
| 87 |
+
total_accepted_tokens: int | None = None
|
| 88 |
|
| 89 |
|
| 90 |
def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
|
|
|
|
| 123 |
"""Initialize the custom candidate generator."""
|
| 124 |
super().__init__(*args, **kwargs)
|
| 125 |
# Initialize probs list if sklearn is available and confidence threshold is enabled
|
| 126 |
+
# Handle both transformers versions (with and without assistant_generation_config)
|
| 127 |
+
assistant_config = getattr(self, 'assistant_generation_config', None)
|
| 128 |
+
if assistant_config is None:
|
| 129 |
+
# Fallback for transformers versions that don't set assistant_generation_config
|
| 130 |
+
assistant_config = self.assistant_model.generation_config
|
| 131 |
+
|
| 132 |
if (
|
| 133 |
is_sklearn_available()
|
| 134 |
+
and hasattr(assistant_config, 'assistant_confidence_threshold')
|
| 135 |
+
and assistant_config.assistant_confidence_threshold
|
| 136 |
):
|
| 137 |
if not hasattr(self, 'probs'):
|
| 138 |
self.probs = []
|
|
|
|
| 164 |
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
| 165 |
|
| 166 |
# Handle sklearn confidence threshold tracking (if enabled)
|
| 167 |
+
# Handle both transformers versions (with and without assistant_generation_config)
|
| 168 |
+
assistant_config = getattr(self, 'assistant_generation_config', None)
|
| 169 |
+
if assistant_config is None:
|
| 170 |
+
assistant_config = self.assistant_model.generation_config
|
| 171 |
+
|
| 172 |
if (
|
| 173 |
is_sklearn_available()
|
| 174 |
+
and hasattr(assistant_config, 'assistant_confidence_threshold')
|
| 175 |
+
and assistant_config.assistant_confidence_threshold
|
| 176 |
and type(self) is RawLogitsCandidateGenerator
|
| 177 |
):
|
| 178 |
scores_tensor = torch.cat(assistant_output.scores, dim=0)
|
|
|
|
| 202 |
is_done_candidate,
|
| 203 |
candidate_logits_raw,
|
| 204 |
fsd_threshold: float = 0.0,
|
| 205 |
+
fsd_div_type: str = "js"
|
| 206 |
):
|
| 207 |
"""
|
| 208 |
Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
|
|
|
|
| 231 |
).sum(dim=-1)
|
| 232 |
elif fsd_div_type == "js":
|
| 233 |
|
| 234 |
+
m = 0.5 * (cand_probs + target_probs[:, :-1, :]) # Midpoint distribution
|
| 235 |
+
divs = (0.5 * torch.nn.functional.kl_div(torch.log(cand_probs), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(target_probs[:, :-1, :]), m, reduction='none')).sum(dim=-1)
|
| 236 |
+
|
| 237 |
+
# m = 0.5 * (cand_probs + target_probs[:, :-1, :]) # Mixture distribution
|
| 238 |
|
| 239 |
+
# # Compute KL(P || M) and KL(Q || M)
|
| 240 |
+
# kl_pm = kl_div(
|
| 241 |
+
# m.log().clamp(min=-1e10), # log-probabilities of mixture
|
| 242 |
+
# cand_probs, # probabilities of candidate
|
| 243 |
+
# reduction='none'
|
| 244 |
+
# )
|
| 245 |
+
# kl_qm = kl_div(
|
| 246 |
+
# m.log().clamp(min=-1e10), # log-probabilities of mixture
|
| 247 |
+
# target_probs[:, :-1, :], # probabilities of target
|
| 248 |
+
# reduction='none'
|
| 249 |
+
# )
|
| 250 |
|
| 251 |
+
# divs = 0.5 * (kl_pm + kl_qm).sum(dim=-1)
|
| 252 |
|
| 253 |
elif fsd_div_type == "draft_tokens":
|
| 254 |
draft_token_ids = new_candidate_input_ids # shape: (batch, candidate_length)
|
|
|
|
| 311 |
assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
| 312 |
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
| 313 |
fsd_threshold: float = 0.0,
|
| 314 |
+
fsd_div_type: str = "js",
|
| 315 |
+
track_acceptance_metrics: bool = False,
|
| 316 |
**model_kwargs,
|
| 317 |
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
|
| 318 |
r"""
|
|
|
|
| 353 |
output_scores = generation_config.output_scores
|
| 354 |
output_logits = generation_config.output_logits
|
| 355 |
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 356 |
+
|
| 357 |
+
# Track draft token acceptance statistics (only if enabled)
|
| 358 |
+
if track_acceptance_metrics:
|
| 359 |
+
total_draft_tokens = 0
|
| 360 |
+
total_accepted_tokens = 0
|
| 361 |
+
else:
|
| 362 |
+
total_draft_tokens = None
|
| 363 |
+
total_accepted_tokens = None
|
| 364 |
|
| 365 |
# init attention / hidden states / scores tuples
|
| 366 |
scores = () if (return_dict_in_generate and output_scores) else None
|
|
|
|
| 450 |
fsd_threshold=fsd_threshold,
|
| 451 |
fsd_div_type=fsd_div_type,
|
| 452 |
)
|
| 453 |
+
# Track acceptance statistics (only if we have draft tokens and tracking is enabled)
|
| 454 |
+
if track_acceptance_metrics and candidate_length > 0:
|
| 455 |
+
total_draft_tokens += candidate_length
|
| 456 |
+
total_accepted_tokens += n_matches
|
| 457 |
|
| 458 |
# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
|
| 459 |
# original model logits with the candidate tokens. We can keep the candidate tokens until the first
|
|
|
|
| 472 |
if is_done_candidate and n_matches == candidate_length:
|
| 473 |
n_matches -= 1
|
| 474 |
valid_tokens = selected_tokens[:, : n_matches + 1]
|
| 475 |
+
|
| 476 |
+
# Track acceptance statistics (for non-sampling case, only if we have draft tokens and tracking is enabled)
|
| 477 |
+
if track_acceptance_metrics and candidate_length > 0:
|
| 478 |
+
total_draft_tokens += candidate_length
|
| 479 |
+
total_accepted_tokens += n_matches
|
| 480 |
|
| 481 |
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
| 482 |
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
|
|
|
|
| 560 |
candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
|
| 561 |
candidate_generator.num_assistant_tokens
|
| 562 |
)
|
| 563 |
+
# Calculate draft token acceptance rate (only if tracking is enabled)
|
| 564 |
+
if track_acceptance_metrics:
|
| 565 |
+
acceptance_rate = total_accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0.0
|
| 566 |
+
else:
|
| 567 |
+
acceptance_rate = None
|
| 568 |
+
total_draft_tokens = None
|
| 569 |
+
total_accepted_tokens = None
|
| 570 |
+
|
| 571 |
if return_dict_in_generate:
|
| 572 |
cache = None
|
| 573 |
if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
|
| 574 |
cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
|
| 575 |
cache = model_kwargs[cache_key]
|
| 576 |
+
# Build base output dict
|
| 577 |
if model.config.is_encoder_decoder:
|
| 578 |
+
base_dict = {
|
| 579 |
+
"sequences": input_ids,
|
| 580 |
+
"scores": scores,
|
| 581 |
+
"logits": raw_logits,
|
| 582 |
+
"encoder_attentions": encoder_attentions,
|
| 583 |
+
"encoder_hidden_states": encoder_hidden_states,
|
| 584 |
+
"decoder_attentions": decoder_attentions,
|
| 585 |
+
"cross_attentions": cross_attentions,
|
| 586 |
+
"decoder_hidden_states": decoder_hidden_states,
|
| 587 |
+
"past_key_values": cache,
|
| 588 |
+
}
|
| 589 |
+
output_class = GenerateEncoderDecoderOutput
|
| 590 |
else:
|
| 591 |
+
base_dict = {
|
| 592 |
+
"sequences": input_ids,
|
| 593 |
+
"scores": scores,
|
| 594 |
+
"logits": raw_logits,
|
| 595 |
+
"attentions": decoder_attentions,
|
| 596 |
+
"hidden_states": decoder_hidden_states,
|
| 597 |
+
"past_key_values": cache,
|
| 598 |
+
}
|
| 599 |
+
output_class = GenerateDecoderOnlyOutput
|
| 600 |
+
|
| 601 |
+
# Try to create output with acceptance rate fields (only if tracking is enabled)
|
| 602 |
+
# If the Hub version doesn't support these fields, create without them
|
| 603 |
+
if track_acceptance_metrics:
|
| 604 |
+
try:
|
| 605 |
+
return output_class(
|
| 606 |
+
**base_dict,
|
| 607 |
+
draft_token_acceptance_rate=acceptance_rate,
|
| 608 |
+
total_draft_tokens=total_draft_tokens,
|
| 609 |
+
total_accepted_tokens=total_accepted_tokens,
|
| 610 |
+
)
|
| 611 |
+
except TypeError:
|
| 612 |
+
# Hub version doesn't support these fields, create without them
|
| 613 |
+
output = output_class(**base_dict)
|
| 614 |
+
# Try to set the fields as attributes (ModelOutput should allow this)
|
| 615 |
+
try:
|
| 616 |
+
output.draft_token_acceptance_rate = acceptance_rate
|
| 617 |
+
output.total_draft_tokens = total_draft_tokens
|
| 618 |
+
output.total_accepted_tokens = total_accepted_tokens
|
| 619 |
+
except Exception:
|
| 620 |
+
# If setting attributes fails, just return without them
|
| 621 |
+
pass
|
| 622 |
+
return output
|
| 623 |
+
else:
|
| 624 |
+
# Tracking disabled, return without metrics
|
| 625 |
+
return output_class(**base_dict)
|
| 626 |
else:
|
| 627 |
return input_ids
|
| 628 |
|
|
|
|
| 649 |
"""
|
| 650 |
# 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
|
| 651 |
# Extract custom parameters before validation (they're not standard generation config params)
|
| 652 |
+
# These are used for loading the custom generate function, not for the generation process itself
|
| 653 |
+
custom_generate = kwargs.pop("custom_generate", None)
|
| 654 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
| 655 |
fsd_threshold = kwargs.pop("fsd_threshold", 0.0)
|
| 656 |
+
fsd_div_type = kwargs.pop("fsd_div_type", "js")
|
| 657 |
+
track_acceptance_metrics = kwargs.pop("track_acceptance_metrics", False)
|
| 658 |
|
| 659 |
generation_mode_kwargs = model._extract_generation_mode_kwargs(
|
| 660 |
None, # custom_generate
|
|
|
|
| 666 |
# Add custom FSD parameters to generation_mode_kwargs so they're passed to _assisted_decoding
|
| 667 |
generation_mode_kwargs["fsd_threshold"] = fsd_threshold
|
| 668 |
generation_mode_kwargs["fsd_div_type"] = fsd_div_type
|
| 669 |
+
generation_mode_kwargs["track_acceptance_metrics"] = track_acceptance_metrics
|
| 670 |
|
| 671 |
# Check length values before updating the config with defaults
|
| 672 |
has_default_max_length = kwargs.get("max_length") is None and (
|
|
|
|
| 914 |
# new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
|
| 915 |
# correction_term = 0
|
| 916 |
|
| 917 |
+
# if div_type != 'sd':
|
| 918 |
|
| 919 |
+
# if div_type == 'kl_div_processed' or div_type == 'js_div_processed' or div_type == 'tv_div_processed':
|
| 920 |
+
# epsilon = 1e-10
|
| 921 |
+
# q = candidate_logits.softmax(dim=-1)
|
| 922 |
+
# p = new_logits[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
|
| 923 |
|
| 924 |
+
# q_nonzero = (p > 0).int()
|
| 925 |
+
# p_nonzero = (q > 0).int()
|
| 926 |
+
# both_nonzero = (q_nonzero & p_nonzero).int()
|
| 927 |
|
| 928 |
+
# # print(f"nonzero q: {q_nonzero.sum(dim=-1)}")
|
| 929 |
+
# # print(f"nonzero p: {p_nonzero.sum(dim=-1)}")
|
| 930 |
+
# # print(f"both nonzero: {both_nonzero.sum(dim=-1)}")
|
| 931 |
|
| 932 |
+
# q = q + epsilon
|
| 933 |
+
# p = p + epsilon
|
| 934 |
|
| 935 |
+
# p = p / p.sum(dim=-1, keepdim=True)
|
| 936 |
+
# q = q / q.sum(dim=-1, keepdim=True)
|
| 937 |
|
| 938 |
|
| 939 |
+
# else:
|
| 940 |
+
# q = candidate_logits_unprocessed.softmax(dim=-1)
|
| 941 |
+
# p = new_logits_unprocessed[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
|
| 942 |
|
| 943 |
+
# if len(div_logits_processor) > 0:
|
| 944 |
+
# epsilon = 1e-10
|
| 945 |
+
# q = q + epsilon
|
| 946 |
+
# p = p + epsilon
|
| 947 |
|
| 948 |
+
# p = p / p.sum(dim=-1, keepdim=True)
|
| 949 |
+
# q = q / q.sum(dim=-1, keepdim=True)
|
| 950 |
|
| 951 |
+
# if div_type == 'kl_div' or div_type == 'kl_div_processed':
|
| 952 |
+
# divs = torch.nn.functional.kl_div(torch.log(p), q, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
|
| 953 |
+
# elif div_type == 'kl_div_reversed' or div_type == 'kl_div_reversed_processed':
|
| 954 |
+
# divs = torch.nn.functional.kl_div(torch.log(q), p, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
|
| 955 |
+
# elif div_type == 'js_div' or div_type == 'js_div_processed':
|
| 956 |
+
# m = 0.5 * (p + q) # Midpoint distribution
|
| 957 |
+
# divs = (0.5 * torch.nn.functional.kl_div(torch.log(p), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(q), m, reduction='none')).sum(dim=-1)
|
| 958 |
# elif div_type == 'tv_div' or div_type == 'tv_div_processed':
|
| 959 |
# divs = 0.5 * torch.abs(p - q).sum(dim=-1)
|
| 960 |
|