# coding=utf-8
# Custom generate function for fuzzy speculative decoding
# Based on transformers.generation.utils with modifications for custom acceptance/rejection logic
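#
# Minimal usage sketch (illustrative only; the model names are placeholders and
# the `custom_generate` repo id is assumed to point at the Hub repo hosting this
# file). Note that the fuzzy acceptance path requires `do_sample=True`:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("<target-model>")
#   draft = AutoModelForCausalLM.from_pretrained("<draft-model>")
#   tokenizer = AutoTokenizer.from_pretrained("<target-model>")
#   inputs = tokenizer("Hello", return_tensors="pt")
#
#   out = model.generate(
#       **inputs,
#       assistant_model=draft,
#       do_sample=True,
#       custom_generate="<this-repo-id>",
#       trust_remote_code=True,
#       fsd_threshold=0.5,
#       fsd_div_type="kl",
#   )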
import copy
import inspect
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import torch
from torch.nn.functional import kl_div
from transformers.cache_utils import Cache
from transformers.generation.candidate_generator import (
AssistedCandidateGenerator,
_prepare_attention_mask,
_prepare_token_type_ids,
)
from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.utils import ModelOutput, is_sklearn_available
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.generation.streamers import BaseStreamer
# Variable names used to hold the cache at generation time
ALL_CACHE_NAMES = [
"past_key_values", # default
"cache_params", # mamba-based models
"state", # rwkv
"mems", # xlnet
"past_buckets_states", # reformer
]
GENERATION_MODES_MAPPING = {
GenerationMode.SAMPLE: "_sample",
GenerationMode.GREEDY_SEARCH: "_sample",
GenerationMode.BEAM_SEARCH: "_beam_search",
GenerationMode.BEAM_SAMPLE: "_beam_search",
GenerationMode.ASSISTED_GENERATION: "_assisted_decoding",
}
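# (GREEDY_SEARCH and SAMPLE share `_sample`, and BEAM_SEARCH/BEAM_SAMPLE share
# `_beam_search`, mirroring the unified decoding methods in recent transformers;
# greedy search is `_sample` with `do_sample=False`.)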
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
"""Outputs of decoder-only generation models, when using non-beam methods."""
sequences: torch.LongTensor
scores: tuple[torch.FloatTensor] | None = None
logits: tuple[torch.FloatTensor] | None = None
attentions: tuple[tuple[torch.FloatTensor]] | None = None
hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
past_key_values: Cache | None = None
@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
"""Outputs of encoder-decoder generation models, when using non-beam methods."""
sequences: torch.LongTensor
scores: tuple[torch.FloatTensor] | None = None
logits: tuple[torch.FloatTensor] | None = None
encoder_attentions: tuple[torch.FloatTensor] | None = None
encoder_hidden_states: tuple[torch.FloatTensor] | None = None
decoder_attentions: tuple[tuple[torch.FloatTensor]] | None = None
cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
past_key_values: Cache | None = None
def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
"""
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
where each member corresponds to a single generated token.
"""
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
# prompt.
if len(outputs) == 0:
new_tuple = ()
for layer in new_outputs:
last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :cur_len, :last_dim_size],)
outputs += (new_tuple,)
# The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
cur_len += 1
added_len -= cur_len
for i in range(added_len):
new_tuple = ()
for layer in new_outputs:
last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,)
return outputs
class RawLogitsCandidateGenerator(AssistedCandidateGenerator):
"""
Custom candidate generator that returns both processed and raw logits from the assistant model.
Extends AssistedCandidateGenerator to support returning raw logits when output_logits=True.
"""
def __init__(self, *args, **kwargs):
"""Initialize the custom candidate generator."""
super().__init__(*args, **kwargs)
# Initialize probs list if sklearn is available and confidence threshold is enabled
if (
is_sklearn_available()
and self.assistant_generation_config.assistant_confidence_threshold
):
if not hasattr(self, 'probs'):
self.probs = []
if not hasattr(self, 'matches'):
self.matches = []
def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None, torch.FloatTensor | None]:
"""
Fetches the candidates to be tried for the current input.
Returns: (candidate_ids, candidate_logits_processed, candidate_logits_raw)
- candidate_logits_processed: Processed logits (scores) from assistant model
- candidate_logits_raw: Raw logits from assistant model (None if output_logits=False)
"""
input_ids = input_ids.to(self.assistant_model.device)
# Calculate new tokens to generate
min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
if max_new_tokens == 0:
return input_ids, None, None
# Update past key values and masks
self._update_past_and_masks(input_ids)
# Generate candidates
generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
candidate_ids, candidate_logits_processed, candidate_logits_raw = self._generate_candidates(generation_args)
return candidate_ids, candidate_logits_processed, candidate_logits_raw
def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor, torch.FloatTensor | None, torch.FloatTensor | None]:
"""Generate candidate sequences using the assistant model, returning both processed and raw logits."""
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
# Handle sklearn confidence threshold tracking (if enabled)
if (
is_sklearn_available()
and self.assistant_generation_config.assistant_confidence_threshold
and type(self) is RawLogitsCandidateGenerator
):
scores_tensor = torch.cat(assistant_output.scores, dim=0)
scores_softmax = torch.softmax(scores_tensor, dim=-1)
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
p = scores_softmax[range(len(ids)), ids]
self.probs.extend(p.tolist())
# Extract processed logits (scores) - always available
candidate_logits_processed = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
# Extract raw logits if available (when output_logits=True)
candidate_logits_raw = None
if self.generation_config.output_logits and hasattr(assistant_output, 'logits') and assistant_output.logits is not None:
candidate_logits_raw = torch.stack(assistant_output.logits, dim=1)
return candidate_ids, candidate_logits_processed, candidate_logits_raw
def _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
next_token_logits,
is_done_candidate,
candidate_logits_raw,
fsd_threshold: float = 0.0,
fsd_div_type: str = "kl"
):
"""
Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
the selected tokens, as well as the number of candidate matches.
NOTE: Unless otherwise stated, the variable names match those in the paper.
"""
new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
# selected by the assistant, respectively.
q = candidate_logits.softmax(dim=-1)
q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
p = new_logits.softmax(dim=-1)
p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
probability_ratio = p_i / q_i
target_probs = next_token_logits.softmax(dim=-1)
cand_probs = candidate_logits_raw.softmax(dim=-1)
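    # NOTE: torch.nn.functional.kl_div(input, target) expects `input` as
    # log-probabilities and computes target * (log(target) - input) elementwise,
    # i.e. KL(target || input-distribution) once summed over the vocab axis.
    # target_probs has one extra (bonus-token) position, hence the [:, :-1, :]
    # slices below.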
if fsd_div_type == "kl":
divs = kl_div(
cand_probs.log().clamp(min=-1e10), # log-probabilities of candidate distribution
target_probs[:, :-1, :], # probabilities of target distribution
reduction='none'
).sum(dim=-1)
elif fsd_div_type == "js":
m = 0.5 * (cand_probs + target_probs[:, :-1, :]) # Mixture distribution
# Compute KL(P || M) and KL(Q || M)
kl_pm = kl_div(
m.log().clamp(min=-1e10), # log-probabilities of mixture
cand_probs, # probabilities of candidate
reduction='none'
)
kl_qm = kl_div(
m.log().clamp(min=-1e10), # log-probabilities of mixture
target_probs[:, :-1, :], # probabilities of target
reduction='none'
)
divs = 0.5 * (kl_pm + kl_qm).sum(dim=-1)
elif fsd_div_type == "draft_tokens":
draft_token_ids = new_candidate_input_ids # shape: (batch, candidate_length)
draft_token_probs_candidate = cand_probs[:, torch.arange(candidate_length), draft_token_ids].squeeze(0, 1)
draft_token_probs_target = target_probs[:, :-1, :][:, torch.arange(candidate_length), draft_token_ids].squeeze(0,
1)
divs = (draft_token_probs_candidate - draft_token_probs_target).abs().sum(dim=-1)
else:
raise ValueError(f"Invalid fsd_div_type: {fsd_div_type}")
# print(f"divs: {divs}")
is_accepted_fsd = divs <= fsd_threshold
# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
# than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
# (= keep with p = probability_ratio). Keep all the tokens until the first rejection
r_i = torch.rand_like(probability_ratio)
is_accepted_sd = r_i <= probability_ratio
is_accepted = is_accepted_fsd | is_accepted_sd
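    # Fuzzy speculative decoding: a draft token survives if it passes either the
    # divergence test or the standard stochastic test. With fsd_threshold=0.0 the
    # divergence test essentially never fires, recovering vanilla speculative decoding.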
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
# print(f"is_accepted_fsd: {is_accepted_fsd}\n is_accepted_sd: {is_accepted_sd}\n is_accepted: {is_accepted}")
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
if is_done_candidate and n_matches == candidate_length:
# Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
# due to acceptance on EOS we fix `n_matches`
n_matches -= 1
valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
else:
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
gamma = candidate_logits.shape[1]
p_n_plus_1 = p[:, n_matches, :]
if n_matches < gamma:
q_n_plus_1 = q[:, n_matches, :]
p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
p_prime.div_(p_prime.sum())
else:
p_prime = p_n_plus_1
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
# The selected tokens include the matches (if any) plus the next sampled tokens
if n_matches > 0:
valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
else:
valid_tokens = t
return valid_tokens, n_matches
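# Shape reference for `_speculative_sampling` (assisted generation enforces
# batch_size == 1 in `_assisted_decoding` below):
#   candidate_input_ids:                      (1, prompt_len + candidate_length)
#   candidate_logits / candidate_logits_raw:  (1, candidate_length, vocab_size)
#   new_logits / next_token_logits:           (1, candidate_length + 1, vocab_size)
#   valid_tokens:                             (1, n_matches + 1)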
def _assisted_decoding(
model,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
inputs_tensor: torch.FloatTensor | None = None,
assistant_model: Optional["PreTrainedModel"] = None,
assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
fsd_threshold: float = 0.0,
fsd_div_type: str = "kl",
**model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
**sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
models.
"""
# The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache
if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
"past_key_values" in model_kwargs
and hasattr(model_kwargs["past_key_values"], "layers")
and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
):
raise ValueError("assisted generate is not supported with Static cache classes`")
    # Create a custom candidate generator that supports raw logits.
    # output_logits is forced to True so that the assistant returns the raw
    # logits needed by the fuzzy divergence test in `_speculative_sampling`.
    if assistant_model is None:
        raise ValueError("assistant_model is required for assisted generation")
    generation_config.output_logits = True
candidate_generator = RawLogitsCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
)
# init values
do_sample = generation_config.do_sample
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and model.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
if batch_size > 1:
raise ValueError("assisted generate is only supported for batch_size = 1")
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
this_peer_finished = False
is_first_iteration = True # to preserve the same API in the output as other generation methods
while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
cur_len = input_ids.shape[1]
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
candidate_input_ids, candidate_logits, candidate_logits_raw = candidate_generator.get_candidates(input_ids)
candidate_input_ids = candidate_input_ids.to(model.device)
if candidate_logits is not None:
candidate_logits = candidate_logits.to(model.device)
if candidate_logits_raw is not None:
candidate_logits_raw = candidate_logits_raw.to(model.device)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
is_done_candidate = stopping_criteria(candidate_input_ids, None)
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.
# 2.1. Prepare the model inputs
candidate_kwargs = copy.copy(model_kwargs)
candidate_kwargs = _prepare_attention_mask(
candidate_kwargs, candidate_input_ids.shape[1], model.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
if "cache_position" in candidate_kwargs:
candidate_kwargs["cache_position"] = torch.cat(
(
candidate_kwargs["cache_position"],
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
),
dim=0,
)
model_inputs = model.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
if "logits_to_keep" in model_inputs:
model_inputs["logits_to_keep"] = candidate_length + 1
# 2.2. Run a forward pass on the candidate sequence
outputs = model(**model_inputs)
# 2.3. Process the new logits
# .float() is needed to retain precision for later logits manipulations
new_logits = outputs.logits[:, -candidate_length - 1 :].to(
dtype=torch.float32, device=input_ids.device
) # excludes the input prompt if present
next_token_logits = new_logits.clone()
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
# 3. Select the accepted tokens. There are two possible cases:
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
# 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192).
if do_sample and candidate_logits is not None:
valid_tokens, n_matches = _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
next_token_logits,
is_done_candidate,
candidate_logits_raw=candidate_logits_raw,
fsd_threshold=fsd_threshold,
fsd_div_type=fsd_div_type,
)
# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
# original model logits with the candidate tokens. We can keep the candidate tokens until the first
# mismatch, or until the max length is reached.
else:
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits.argmax(dim=-1)
candidate_new_tokens = candidate_input_ids[:, cur_len:]
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
# Ensure we don't generate beyond max_len or an EOS token
if is_done_candidate and n_matches == candidate_length:
n_matches -= 1
valid_tokens = selected_tokens[:, : n_matches + 1]
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
# is no match.
# 4.1. Get the valid continuation, after the matching tokens
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[1]
# 4.2. Discard past key values relative to unused assistant tokens
outputs.past_key_values.crop(new_cur_len - 1)
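        # (after cropping, the cache covers every accepted token except the last
        # one, which is fed back through the model as the next forward input)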
# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = model._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=model.config.is_encoder_decoder,
num_new_tokens=n_matches + 1,
)
if synced_gpus and this_peer_finished:
continue
# Store scores, attentions and hidden_states when required
# Assistant: modified to append one tuple element per token, as in the other generation methods.
if return_dict_in_generate:
newly_added_length = n_matches + 1
if output_scores:
scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
if output_logits:
raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))
newly_added_length = new_cur_len if is_first_iteration else newly_added_length
if output_attentions:
if model.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
cur_len,
newly_added_length,
is_decoder_attention=True,
)
# some (V)LLMs have hard requirement on SDPA and thus never return attn
elif outputs.attentions[0] is not None:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
cur_len,
newly_added_length,
is_decoder_attention=True,
)
if output_hidden_states:
if model.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
is_first_iteration = False
if streamer is not None:
streamer.end()
if (
hasattr(candidate_generator, "assistant_model")
and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
):
candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
candidate_generator.num_assistant_tokens
)
if return_dict_in_generate:
cache = None
if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
cache = model_kwargs[cache_key]
if model.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=cache,
)
else:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=cache,
)
else:
return input_ids
def generate(
model,
inputs: torch.Tensor | None = None,
generation_config: GenerationConfig | None = None,
logits_processor: LogitsProcessorList | None = None,
stopping_criteria: StoppingCriteriaList | None = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
synced_gpus: bool | None = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: torch.Tensor | None = None,
negative_prompt_attention_mask: torch.Tensor | None = None,
**kwargs,
) -> GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput | torch.LongTensor:
r"""
Generates sequences of token ids for models with a language modeling head.
This is a custom generate function that replaces the standard one. It supports all standard generation modes
and includes custom speculative decoding acceptance/rejection logic.
"""
# 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
# Extract custom parameters before validation (they're not standard generation config params)
fsd_threshold = kwargs.pop("fsd_threshold", 0.0)
fsd_div_type = kwargs.pop("fsd_div_type", "kl")
generation_mode_kwargs = model._extract_generation_mode_kwargs(
None, # custom_generate
kwargs,
synced_gpus,
assistant_model,
streamer,
)
# Add custom FSD parameters to generation_mode_kwargs so they're passed to _assisted_decoding
generation_mode_kwargs["fsd_threshold"] = fsd_threshold
generation_mode_kwargs["fsd_div_type"] = fsd_div_type
# Check length values before updating the config with defaults
has_default_max_length = kwargs.get("max_length") is None and (
generation_config is None or generation_config.max_length is None
)
has_default_min_length = kwargs.get("min_length") is None and (
generation_config is None or generation_config.min_length is None
)
generation_config, model_kwargs = model._prepare_generation_config(generation_config, **kwargs)
generation_mode = generation_config.get_generation_mode(assistant_model)
# type() required to access the unbound class-level method
decoding_method = getattr(type(model), GENERATION_MODES_MAPPING[generation_mode])
model._validate_model_kwargs(model_kwargs.copy())
model._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = "attention_mask" in set(inspect.signature(model.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
# 3. Define model inputs
inputs_tensor, model_input_name, model_kwargs = model._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
# Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward()
if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys():
generation_mode_kwargs["inputs_tensor"] = inputs_tensor
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
model._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
# decoder-only models must use left-padding for batched generation.
if not model.config.is_encoder_decoder:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
if (
generation_config._pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
):
            from transformers.utils import logging
            logger = logging.get_logger(__name__)
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
# 4. Define other model kwargs
# decoder-only models with inputs_embeds forwarding must use caching
if not model.config.is_encoder_decoder and model_input_name == "inputs_embeds":
generation_config.use_cache = True
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
inputs_tensor, generation_config, model_kwargs
)
elif kwargs_has_attention_mask:
# TODO (joao): generalize this check with other types of inputs
if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
raise ValueError("`attention_mask` passed to `generate` must be 2D.")
if model.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name, generation_config
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if model.config.is_encoder_decoder:
input_ids, model_kwargs = model._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
device=inputs_tensor.device,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# Expand inputs depending on the generation mode
input_ids, model_kwargs = model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=max(generation_config.num_beams, generation_config.num_return_sequences),
is_encoder_decoder=model.config.is_encoder_decoder,
**model_kwargs,
)
if generation_config.token_healing:
input_ids = model.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
if streamer is not None:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[1]
generation_config = model._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
# If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
# dynamically overrides this value as it can need more than the last token logits
if model._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
model_kwargs["logits_to_keep"] = 1
model._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. Prepare the cache.
max_cache_length = generation_config.max_length - 1
if (
inputs_tensor.shape[1] != input_ids_length
and model_input_name == "inputs_embeds"
and not model.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]
model._prepare_cache_for_generation(
generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
)
if model.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
f" is on {model.device.type}. You may experience unexpected behaviors or slower generation."
" Please make sure that you have put `input_ids` to the"
f" correct device by calling for example input_ids = input_ids.to('{model.device.type}') before"
" running `.generate()`.",
UserWarning,
)
# 8. Prepare logits processors and stopping criteria
prepared_logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
prepared_stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
tokenizer=generation_mode_kwargs.get("tokenizer"),
)
# Set model_kwargs `use_cache` so we can use it later in forward runs
model_kwargs["use_cache"] = generation_config.use_cache
# 9. Call generation mode
# For assisted generation, use our custom function
if generation_mode == GenerationMode.ASSISTED_GENERATION:
result = _assisted_decoding(
model,
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)
else:
# For other modes, use the model's standard methods
result = decoding_method(
model,
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)
return result