import copy
import weakref
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
import torch
import torch.nn as nn

from ..pytorch_utils import isin_mps_friendly, prune_linear_layer
from ..utils import is_sklearn_available


if is_sklearn_available():
    from sklearn.metrics import roc_curve

from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from ..tokenization_utils_base import PreTrainedTokenizerBase
    from .configuration_utils import GenerationConfig


class CandidateGenerator:
    """Abstract base class for all candidate generators that can be applied during assisted generation."""

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
        )

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
            "`update_candidate_strategy`."
        )


class AssistedCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
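
    This generator is typically not instantiated directly: `generate` builds it when `assistant_model` is passed.
    A minimal sketch (the checkpoint names below are illustrative placeholders, not a recommendation):

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-large")
    >>> assistant = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> inputs = tokenizer("The capital of France is", return_tensors="pt")

    >>> # passing `assistant_model` makes `generate` drive decoding through this candidate generator
    >>> outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
    ```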
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        model_kwargs: dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: Optional["LogitsProcessorList"] = None,
    ):
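        # Make sure all data is on the same device as the assistant model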
        device = assistant_model.device
        input_ids = input_ids.to(device)
        if inputs_tensor is not None:
            inputs_tensor = inputs_tensor.to(device)
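
        # Prepare the assistant and the number of candidate tokens to start with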
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
        self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold
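
        # Set eos in assistant same as in target model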
        self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
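
        # Prepare the kwargs for the assistant model, skipping inputs that are unique to the main model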
        assistant_kwargs = {}
        for key, value in model_kwargs.items():
            if key not in ("encoder_outputs", "past_key_values"):
                assistant_kwargs[key] = (
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )
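
        # Remove potential default "logits_to_keep" key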
        if "logits_to_keep" in assistant_kwargs and not assistant_model._supports_logits_to_keep():
            del assistant_kwargs["logits_to_keep"]

        if assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs
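
        # Prepare assistant model's keys of inputs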
        if assistant_model.config.is_encoder_decoder:
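            # both the main and assistant models are encoder-decoder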
            self.input_ids_key = "decoder_input_ids"
        elif "encoder_outputs" in assistant_kwargs:
            self.input_ids_key = "input_ids"
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
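            # both the main and assistant models are decoder-only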
            self.input_ids_key = "input_ids"
        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        self.generation_config = copy.deepcopy(generation_config)

        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True
        self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
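        # this flag allows us to set the confidence stopping criteria for the assistant model's generation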
        self.generation_config.is_assistant = True
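
        # `min_length` is enforced by the main model, so drop it here to avoid spurious warnings; a
        # `MinLengthLogitsProcessor` cannot be supported in assisted generation and is rejected below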
        self.main_model_min_length = self.generation_config.min_length
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None
        for processor in self.logits_processor:
            if isinstance(processor, MinLengthLogitsProcessor):
                raise ValueError(
                    "Passing `MinLengthLogitsProcessor` is disabled when using assisted generation. "
                    "Please pass `min_length` to `.generate()` instead."
                )
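
        # Assisted generation rolls the cache back after each round of candidates, which requires a dynamic cache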
        self.generation_config.cache_implementation = "dynamic_full"
        if (
            is_sklearn_available()
            and self.assistant_model.generation_config.assistant_confidence_threshold
            and type(self) is AssistedCandidateGenerator
        ):
            self.probs = []
            self.matches = []

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        input_ids = input_ids.to(self.assistant_model.device)
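        # Calculate new tokens to generate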
        min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
        if max_new_tokens == 0:
            return input_ids, None
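        # Update past key values and masks, then generate the candidates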
        self._update_past_and_masks(input_ids)
        generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
        candidate_ids, candidate_logits = self._generate_candidates(generation_args)
        return candidate_ids, candidate_logits

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
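            # if all candidates were accepted, draft more aggressively next round; otherwise, back off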
            if num_matches == len(scores[0]) - 1:
                self.num_assistant_tokens += 2.0
            else:
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
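
        # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the
        # number of unnecessary draft and target forward passes, based on the ROC curve over the recorded draft
        # probabilities and their acceptance by the target model.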
        if (
            is_sklearn_available()
            and self.assistant_model.generation_config.assistant_confidence_threshold
            and type(self) is AssistedCandidateGenerator
        ):
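            # the first `num_matches` drafted tokens were accepted; the next drafted token (if any) was rejected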
            self.matches.extend([1] * num_matches)
            if len(self.probs) > len(self.matches):
                self.matches.append(0)
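
            # drop the probabilities of drafted tokens the target model never got to assess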
            excess_length = len(self.probs) - len(self.matches)
            if excess_length > 0:
                del self.probs[-excess_length:]
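
            # update the threshold only when there is enough data and both outcomes have been observed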
            if len(self.probs) > 5 and {0, 1}.issubset(self.matches):
                fpr, tpr, thresholds = roc_curve(self.matches, self.probs)
                fnr = 1 - tpr
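
                # stopping the draft although the token would have been accepted (a false negative) is weighted
                # three times as costly as drafting a token that ends up rejected (a false positive)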
                costs = fpr + 3 * fnr
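
                # pick the threshold that minimizes the cost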
                optimal_threshold_index = np.argmin(costs)
                best_threshold = thresholds[optimal_threshold_index]

                self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold

    def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> tuple[int, int]:
        """Calculate the minimum and maximum number of new tokens to generate."""
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
        return min_new_tokens, max_new_tokens

    def _update_past_and_masks(
        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
    ) -> bool:
        """Update past key values and attention masks for subsequent generation rounds."""
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
            self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens)
            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
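
            # the cache was created during the first `generate` call and is reused (and cropped) afterwards,
            # so we should no longer request a fresh cache implementation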
            self.generation_config.cache_implementation = None

        return has_past_key_values

    def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> dict:
        """Prepare arguments for the generation call."""
        return {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }

    def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """Generate candidate sequences using the assistant model."""
        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
        if (
            is_sklearn_available()
            and self.assistant_model.generation_config.assistant_confidence_threshold
            and type(self) is AssistedCandidateGenerator
        ):
            scores_tensor = torch.cat(assistant_output.scores, dim=0)
            scores_softmax = torch.softmax(scores_tensor, dim=-1)
            ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
            p = scores_softmax[range(len(ids)), ids]
            self.probs.extend(p.tolist())

        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        candidate_ids = assistant_output.sequences
        return candidate_ids, candidate_logits


class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
    """
    `CandidateGenerator` class to be used for Universal Assisted Generation (UAG): assisted generation with different
    tokenizers for the assistant and main models. This class generates candidates through the use of a smaller model.

    The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the
    assistant encoding, which are in turn re-encoded into main model candidate tokens. Validation then proceeds as
    explained above. The re-encoding steps involve decoding token ids into text and then encoding the text using a
    different tokenizer. Since re-encoding the tokens may result in tokenization discrepancies, UAG finds the longest
    common subsequence between the source and target encodings, to ensure the new tokens include the correct prompt
    suffix.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        target_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for the target model.
        assistant_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for the assistant model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
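
    As with `AssistedCandidateGenerator`, this class is built by `generate` rather than instantiated by hand. A
    minimal sketch (the checkpoint names are illustrative placeholders):

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    >>> assistant_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
    >>> assistant = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> inputs = tokenizer("The capital of France is", return_tensors="pt")

    >>> # both tokenizers must be passed so that `generate` can translate between the two vocabularies
    >>> outputs = model.generate(
    ...     **inputs, assistant_model=assistant, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer
    ... )
    ```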
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        generation_config: "GenerationConfig",
        model_kwargs: dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: Optional["LogitsProcessorList"] = None,
    ):
        super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor)

        self.target_tokenizer = target_tokenizer
        self.assistant_tokenizer = assistant_tokenizer
        self.prev_target_ids_len: Optional[int] = None
        self.prev_assistant_ids = None
        self.target_lookbehind = assistant_model.generation_config.target_lookbehind
        self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind

    @staticmethod
    def _get_longest_diag_dict(input_matrix, nonzero_idx):
        """
        Calculates the length of the longest diagonal sequence in a given matrix.

        Args:
            input_matrix (torch.Tensor): The input matrix.
            nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix.

        Returns:
            dict: A dictionary where the keys are the indices of the non-zero elements and the values are the lengths
            of the longest diagonal sequences starting from those indices.
        """
        visited = set()
        diags = {}
        for idx in nonzero_idx:
            start_idx = torch.clone(idx)
            tuple_start_idx = tuple(start_idx.tolist())

            if tuple_start_idx in visited:
                continue

            visited.add(tuple_start_idx)
            cur_diag_len = 1
            start_idx += 1
            while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]:
                tuple_start_idx = tuple(start_idx.tolist())
                visited.add(tuple_start_idx)

                if input_matrix[start_idx[0], start_idx[1]] == 1:
                    cur_diag_len += 1
                    start_idx += 1
                else:
                    break

            diags[idx] = cur_diag_len
        return diags

    @staticmethod
    def _get_longest_diag_index(input_matrix):
        """
        Returns the start index and length of the longest diagonal in the given input.

        Args:
            input_matrix (torch.Tensor): The input matrix.

        Returns:
            tuple: A tuple containing the start index and length of the longest diagonal.
        """
        diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict(
            input_matrix, input_matrix.nonzero()
        )
        diags_values = list(diags.values())
        diags_keys = list(diags.keys())
        best_diag = np.argmax(diags_values)
        diag_start_index = diags_keys[best_diag]
        diag_start_length = diags_values[best_diag]
        return diag_start_index, diag_start_length

    @staticmethod
    def _get_tokens_diag(prompt, prompt_plus_new_tokens):
        """
        Input:
            prompt: 2D array of shape (batch_size, prompt_length), represents the original prompt tokens
            prompt_plus_new_tokens: 2D array of shape (batch_size, some_length), represents a suffix of the original
                prompt together with additional new tokens
        Output:
            discrepancy_length: int, represents the number of tokens that need to be replaced from prompt
            new_tokens_only: 2D array of shape (batch_size, new_token_length), represents the new tokens that are not
                in prompt
            discrepancy_only: 2D array of shape (batch_size, discrepancy_length), represents the tokens of
                prompt_plus_new_tokens that overlap the prompt but disagree with it
        """
        compare_mat = prompt_plus_new_tokens.T == prompt
        if not torch.is_tensor(compare_mat):
            compare_mat = torch.tensor(compare_mat)

        compare_mat_int = compare_mat.to(int)

        if not compare_mat_int.any().item():
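            # empty intersection between prompt and prompt_plus_new_tokens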
            return None, None, None

        longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(
            compare_mat_int
        )
        new_token_start_index = longest_location[0] + longest_diag_length
        discrepancy_with_old = longest_location[1] + longest_diag_length
        discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item()
        new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :]
        discrepancy_only = prompt_plus_new_tokens[
            :, new_token_start_index : new_token_start_index + discrepancy_length
        ]
        return discrepancy_length, new_tokens_only, discrepancy_only

    def convert_source_tokens_to_target_tokens(
        self,
        input_ids,
        source_tokenizer,
        destination_tokenizer,
    ):
        """
        Convert token IDs from one tokenizer to another.

        Args:
            input_ids: The input token IDs.
            source_tokenizer: The source tokenizer.
            destination_tokenizer: The destination tokenizer.

        Returns:
            The converted token IDs.
        """
        text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
        return dest_ids.to(input_ids.device)

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        max_new_tokens = int(self.num_assistant_tokens)
        if max_new_tokens == 0:
            return input_ids, None

        input_ids = input_ids.to(self.assistant_model.device)
        remove_from_pkv = 0

        assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
        self.prev_assistant_ids = assistant_input_ids

        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)

        self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
        generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
        self.assistant_kwargs.pop("attention_mask", None)

        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
        new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences)
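
        # Update state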
        self.prev_target_ids_len = input_ids.shape[1]
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
        self.prev_assistant_ids = assistant_output.sequences

        if self.prev_target_ids_len >= new_target_ids.shape[1]:
            return input_ids, None

        return new_target_ids, None

    def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, int]:
        """Converts target input IDs to assistant input IDs, handling discrepancies."""
        convert_kwargs = {
            "source_tokenizer": self.target_tokenizer,
            "destination_tokenizer": self.assistant_tokenizer,
        }
        remove_from_pkv = 0

        if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
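            # input_ids contains all target prompt input ids and some new target input ids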
            start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind

            new_assistant_ids = self.convert_source_tokens_to_target_tokens(
                input_ids[:, start_index_in_target_window:], **convert_kwargs
            )
            prompt_use_length = new_assistant_ids.shape[1]
            prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]

            discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
                prompt_use, new_assistant_ids
            )
            assistant_input_ids = self.prev_assistant_ids

            if new_tokens_only is not None:
                if discrepancy_length > 0 and discrepancy_only.shape[1] > 0:
                    if discrepancy_length == discrepancy_only.shape[1]:
                        assistant_input_ids[:, -discrepancy_length:] = discrepancy_only

                    elif discrepancy_length > discrepancy_only.shape[1]:
                        discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1]
                        assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff]
                        assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only

                    remove_from_pkv = discrepancy_length

                if new_tokens_only.shape[1] > 0:
                    assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1)
            else:
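                # edge case: there is no intersection between the prompt and new_assistant_ids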
                assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
        else:
            assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
            self.prev_target_ids_len = input_ids.shape[1]

        return assistant_input_ids, remove_from_pkv

    def _process_assistant_outputs(
        self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor
    ) -> torch.LongTensor:
        """Processes assistant outputs to obtain target input IDs."""
        num_prev_assistant = self.prev_assistant_ids.shape[1]
        start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind

        new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
            assistant_sequences[:, start_assistant_look_index:],
            source_tokenizer=self.assistant_tokenizer,
            destination_tokenizer=self.target_tokenizer,
        )
        target_prompt_use_length = new_target_ids_from_window.shape[1]

        target_prompt_use = input_ids[:, -target_prompt_use_length:]

        _, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)

        new_target_ids = input_ids

        if target_new_tokens_only is not None:
            if target_new_tokens_only.shape[1] > 0:
                new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1)
        else:
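            # edge case: there is no intersection between the prompt and new_target_ids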
            new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)

        if hasattr(self.generation_config, "max_length"):
            new_target_ids = new_target_ids[:, : self.generation_config.max_length]

        return new_target_ids


class _PruneReindexingLMHead(nn.Module):
    """
    A class to prune and reindex the language model head.

    This class prunes the language model head to only include the specified token IDs and reindexes the logits
    to map back to the original vocabulary.

    Args:
        original_lm_head (nn.Module): The original language model head.
        assistant_overlap_token_ids (torch.LongTensor): The token IDs to keep.
    """

    def __init__(self, original_lm_head, assistant_overlap_token_ids):
        super().__init__()
        self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(
            original_lm_head.weight.dtype
        )

    def forward(self, hidden_states):
        pruned_logits = self.pruned_lm_head(hidden_states)
        return pruned_logits


class _MapInputEmbedding(nn.Module):
    def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids):
        """
        Wraps an existing embedding layer and remaps token IDs before lookup.

        Args:
            original_embedding (nn.Embedding): Pre-trained or existing embedding layer.
            assistant_overlap_token_ids (torch.LongTensor): 1D tensor mapping token IDs from the pruned vocabulary
                back to the assistant model's full vocabulary.
        """
        super().__init__()
        self.original_embedding = original_embedding
        self.weight = original_embedding.weight
        self.assistant_overlap_token_ids = assistant_overlap_token_ids
        self.map = False

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        Args:
            input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len).

        Returns:
            torch.FloatTensor: Corresponding input embeddings.
        """
        if self.map:
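            # the last generated token is in the pruned vocabulary space: map it back before the lookup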
            my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0)
        else:
            self.map = True
            my_input_ids = input_ids

        return self.original_embedding(my_input_ids)


class AssistantToTargetTranslator:
    """
    Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
    vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
    as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
    (https://huggingface.co/papers/2502.05202).
    It maintains mappings between the two vocabularies and handles token/logit conversion.

    Args:
        target_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used by the target (main) model.
        assistant_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used by the assistant model.
        target_vocab_size (`int`):
            The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
        assistant_model (`PreTrainedModel`, *optional*):
            The assistant model to be used. Defaults to `None` for backward compatibility.
        assistant_prune_lm_head (`bool`):
            Whether to prune the assistant model's language model head to match the target vocabulary. This is only
            applicable if `assistant_model` is provided. Defaults to `False` for backward compatibility.
    """

    FILTER_VALUE: float = -float("Inf")
    SUPPRESS_TOKEN_ID: int = -1

    def __init__(
        self,
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        target_vocab_size: int,
        assistant_model: Optional["PreTrainedModel"] = None,
        assistant_prune_lm_head: bool = False,
    ):
        self._target_tokenizer: PreTrainedTokenizerBase = target_tokenizer
        self._assistant_tokenizer: PreTrainedTokenizerBase = assistant_tokenizer
        self._assistant_model_device = assistant_model.device if assistant_model is not None else "cpu"
        self.target_vocab_size: int = target_vocab_size
        self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
            self._get_assistant_to_target_input_ids()
        )
        self._suppress_input_ids: torch.Tensor = self._get_suppress_input_ids()
        self.logits_processors: Optional[LogitsProcessorList] = None
        self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None
        if len(self._suppress_input_ids) > 0:
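            # the assistant vocab is not a subset of the target vocab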
            if self.assistant_prune_lm_head:
                self.assistant_overlap_token_ids = torch.tensor(
                    list(self.target_to_assistant_input_ids.values()),
                    dtype=torch.long,
                    device=self._assistant_model_device,
                )
                original_lm_head = assistant_model.get_output_embeddings()
                pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids)
                del original_lm_head
                assistant_model.set_output_embeddings(pruned_lm_head)

                original_input_embeddings = assistant_model.get_input_embeddings()
                map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids)
                del original_input_embeddings
                assistant_model.set_input_embeddings(map_input_embeddings)
                self.map_input_embeddings = map_input_embeddings
            else:
                self.logits_processors = LogitsProcessorList(
                    [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
                )

    def unmap_input_ids(self):
        """
        Disables the mapping of input ids despite the assistant pruning for the language model head being enabled.

        This method is required for the first forward pass of `_MapInputEmbedding`, where the input ids are already in
        the assistant vocabulary space. By disabling the mapping, it ensures that those input ids are processed
        correctly without being remapped.
        """
        if self.assistant_prune_lm_head:
            self.map_input_embeddings.map = False
def _get_assistant_to_target_input_ids(self): |
|
|
target_vocab = self._target_tokenizer.get_vocab() |
|
|
assistant_vocab = self._assistant_tokenizer.get_vocab() |
|
|
|
|
|
space_str = " " |
|
|
target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"] |
|
|
if len(target_space_ids) > 0: |
|
|
target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0] |
|
|
|
|
|
assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"] |
|
|
if len(assistant_space_ids) > 0: |
|
|
assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0] |
|
|
|
|
|
if target_space_sign != assistant_space_sign: |
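                    # the tokenizers mark leading spaces differently (e.g. "Ġ" vs "▁"): normalize the assistant
                    # vocab to the target convention so that shared tokens line up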
                    assistant_vocab = {
                        (
                            tok.replace(assistant_space_sign, target_space_sign, 1)
                            if tok.startswith(assistant_space_sign)
                            else tok
                        ): idx
                        for tok, idx in assistant_vocab.items()
                    }

        max_assistant_index = max(assistant_vocab.values())
        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
        target_to_assistant_input_ids: dict[int, int] = {}
        for tok, assistant_id in assistant_vocab.items():
            target_id = target_vocab.get(tok)
            if target_id is not None:
                assistant_to_target_input_ids[assistant_id] = target_id
                target_to_assistant_input_ids[target_id] = assistant_id
        return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids

    def _get_suppress_input_ids(self) -> torch.Tensor:
        """
        Get the input ids that are in the assistant vocab but not in the target vocab.
        """
        return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]

    def get_target_ids(
        self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
    ) -> torch.LongTensor:
        """
        Return the target candidate ids that correspond to the assistant candidate ids.
        Note that we already have the target ids for the prompt and only need to find the target ids for the new
        tokens. Moreover, assistant ids of the original prompt do not necessarily appear in
        _assistant_to_target_input_ids.
        """
        num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
        if num_new_tokens == 0:
            return target_input_ids
        else:
            last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:]
            if self.assistant_prune_lm_head:
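                # map ids from the pruned vocabulary space back to assistant token ids before translating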
                last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids]
            transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids]
            return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)

    def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
        """
        Return the target logits that correspond to the assistant logits.
        """
        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
        target_logits: torch.FloatTensor = torch.full(
            target_shape, self.FILTER_VALUE, device=self._assistant_model_device
        )
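
        # mask selecting the assistant tokens that exist in the target vocab, and their target indices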
        assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]

        if self.assistant_prune_lm_head:
            target_logits[..., target_logits_supported_indices] = assistant_logits
        else:
            valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
            target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
        return target_logits


class AssistantVocabTranslatorCache:
    """
    Cache for `AssistantToTargetTranslator` instances. The instances are computed at
    pre-processing time, and this cache allows us to avoid recomputing them.
    """

    _cache = weakref.WeakKeyDictionary()

    @classmethod
    def get_translator(
        cls,
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        target_vocab_size: int,
        assistant_model: Optional["PreTrainedModel"] = None,
        assistant_prune_lm_head: bool = False,
    ) -> AssistantToTargetTranslator:
        assistant_dict = cls._cache.get(target_tokenizer)
        if assistant_dict is None:
            assistant_dict = weakref.WeakKeyDictionary()
            cls._cache[target_tokenizer] = assistant_dict

        mapping = assistant_dict.get(assistant_tokenizer)
        if mapping is None:
            mapping = AssistantToTargetTranslator(
                target_tokenizer,
                assistant_tokenizer,
                target_vocab_size,
                assistant_model,
                assistant_prune_lm_head,
            )
            assistant_dict[assistant_tokenizer] = mapping

        return mapping

    @classmethod
    def cleanup(cls):
        """
        Clean up dead references in the cache.
        This removes entries where either the target_tokenizer or assistant_tokenizer
        has been garbage collected.
        """
        dead_keys = [key for key in cls._cache if key is None]
        for key in dead_keys:
            del cls._cache[key]
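
        # for each assistant_dict, remove entries whose assistant_tokenizer has been collected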
        for assistant_dict in cls._cache.values():
            dead_keys = [key for key in assistant_dict if key is None]
            for key in dead_keys:
                del assistant_dict[key]


class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
    """
    `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with
    different tokenizers for the assistant and main models. This class generates candidates through the use of a
    smaller model.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        generation_config: "GenerationConfig",
        model_kwargs: dict,
        atm_translator: AssistantToTargetTranslator,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: Optional["LogitsProcessorList"] = None,
    ):
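        # initialize the translator before the parent class, which may rely on it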
        self._atm_translator = atm_translator
        super().__init__(
            input_ids,
            assistant_model,
            target_tokenizer,
            assistant_tokenizer,
            generation_config,
            model_kwargs,
            inputs_tensor,
            logits_processor,
        )

        self._target_seq_len_with_candidates: int = 0
        self._prev_assistant_ids: Optional[torch.LongTensor] = None

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Simplified version of get_candidates that uses the translator cache for token conversion.
        """
        target_input_ids = input_ids.to(self.assistant_model.device)
        assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
        min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)

        if max_new_tokens == 0:
            return input_ids, None

        self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
        generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
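
        # ensure scores are returned so the logits can be translated into the target vocabulary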
        generation_args["generation_config"].output_scores = True
        generation_args["generation_config"].return_dict_in_generate = True
        if self._atm_translator.logits_processors is not None:
            generation_args["logits_processor"] = self._atm_translator.logits_processors
        self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)
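
        # use the translator to convert the assistant candidate tokens and logits to the target vocabulary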
        target_candidate_ids = self._atm_translator.get_target_ids(
            assistant_input_ids, target_input_ids, self._prev_assistant_ids
        )
        self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
        target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)

        return target_candidate_ids, target_candidate_logits

    def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
        if self._prev_assistant_ids is None:
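            # prepare the attention mask for the first generation round;
            # subsequent rounds are handled by the parent class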
            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
            )
        return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)

    def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> tuple[torch.LongTensor, int]:
        """
        Simplified token conversion that only processes new tokens.
        """
        target_seq_len = target_input_ids.shape[-1]
        if self._target_seq_len_with_candidates == 0:
            new_token_count = target_seq_len
        else:
            new_token_count = 1
        target_new_ids = target_input_ids[:, -new_token_count:]
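
        # convert the new target tokens to the assistant vocabulary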
        assistant_new_ids = None
        if self._target_seq_len_with_candidates > 0:
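            # we have only one new token, so first try to look it up directly in the vocabulary mapping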
            assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
        if assistant_new_ids is None:
            target_new_text = self.target_tokenizer.batch_decode(
                target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            assistant_new_ids = self.assistant_tokenizer(
                target_new_text, add_special_tokens=False, return_tensors="pt"
            )["input_ids"].to(self.assistant_model.device)
        else:
            assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)
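
        # update or initialize the assistant ids with the newly converted tokens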
        if self._prev_assistant_ids is None:
            assistant_input_ids = assistant_new_ids
        else:
            tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
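            # candidate tokens that the target model rejected are dropped from the cached assistant ids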
            if tokens_to_remove > 0:
                self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
            assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
        assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
        self._atm_translator.unmap_input_ids()
        return assistant_input_ids, len(assistant_new_ids[0])


class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
    likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        eos_token_id (`torch.Tensor`, *optional*):
            The token id of the end of sequence token.
        num_output_tokens (`int`, *optional*, defaults to 10):
            The number of tokens to be output as candidate tokens.
        max_matching_ngram_size (`int`, *optional*, defaults to 2):
            The maximum ngram size to be considered for matching in the prompt.
        max_length (`int`, *optional*, defaults to 20):
            The maximum total number of tokens that can be generated. For decoder-only models, this includes the
            prompt length. Defaults to 20, which is the default max length in the generation config.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step. In
            prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
            forbidden tokens (p = -inf) and block them from being valid candidates.
        vocab_size (`int`, *optional*):
            The size of the vocabulary. Required if `logits_processor` is provided.
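
    Prompt lookup is enabled through `generate` rather than by instantiating this class directly. A minimal sketch
    (the checkpoint name is an illustrative placeholder):

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> inputs = tokenizer("Repeat after me: the quick brown fox. The quick brown", return_tensors="pt")

    >>> # `prompt_lookup_num_tokens` switches assisted decoding to prompt lookup candidates
    >>> outputs = model.generate(**inputs, prompt_lookup_num_tokens=10, max_new_tokens=20)
    ```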
    """

    def __init__(
        self,
        eos_token_id: Optional[torch.Tensor] = None,
        num_output_tokens: int = 10,
        max_matching_ngram_size: int = 2,
        max_length: int = 20,
        logits_processor: Optional["LogitsProcessorList"] = None,
        vocab_size: Optional[int] = None,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size
        self.max_length = max_length
        self.eos_token_id = eos_token_id
        self.logits_processor = logits_processor
        self.vocab_size = vocab_size

        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
        """
        bsz, input_length = input_ids.shape
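
        # don't generate more than `max_length - 1` candidates, since the target model generates one extra token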
        if self.max_length == input_length + 1:
            return input_ids, None

        chosen_ids = None
        match_found = False

        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
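            # create sliding windows of size ngram_size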
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
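
            # convert the ngram to a tensor for comparison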
            ngram_tensor = input_ids[0, -ngram_size:]
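
            # find where the windows match the ngram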
            matches = (windows == ngram_tensor).all(dim=2)
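
            # get the indices of matches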
            match_indices = matches.nonzero(as_tuple=True)[1]
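
            # iterate through match indices to find the first valid continuation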
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length, self.max_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
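
                    # check whether each new candidate token is forbidden according to the logits processor;
                    # if one is, crop the chosen ids just before it and stop there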
                    if self.logits_processor is not None:
                        sequence_with_candidate = input_ids
                        fake_input_logits = torch.ones(
                            (bsz, self.vocab_size), device=input_ids.device, dtype=torch.float32
                        )
                        for candidate_idx, new_candidate_token in enumerate(chosen_ids):
                            fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
                            fake_candidate_logits = fake_output_logits[0, new_candidate_token]

                            if fake_candidate_logits in (-float("Inf"), torch.finfo(fake_candidate_logits.dtype).min):
                                chosen_ids = chosen_ids[:candidate_idx]
                                break
                            else:
                                sequence_with_candidate = torch.cat(
                                    (input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1
                                )

                    if chosen_ids.shape[0] == 0:
                        continue

                    match_found = True
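
                    # remove the remaining candidate ids if an "eos" token is found, otherwise the target model may
                    # accept eos and the rest as valid, thus not stopping generation after "eos"
                    # NOTE: the code below relies on the fact that assisted decoding supports only batch size 1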
                    mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
                    match_indices_eos = torch.nonzero(mask)
                    if match_indices_eos.numel() > 0:
                        first_eos_index = match_indices_eos[0].item()
                        chosen_ids = chosen_ids[:first_eos_index]
                    break
            if match_found:
                break
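
        # in case we didn't find a match, return the input ids unchanged (no candidates)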
        if not match_found or len(chosen_ids) == 0:
            return input_ids, None
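
        # extend input_ids with the chosen candidate ids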
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)

        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        return


class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
    exit, e.g., `facebook/layerskip-llama3.2-1B`.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The original model. This model must support early exit (i.e. is trained to compute logits in earlier
            layers).
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        model_kwargs: dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: Optional["LogitsProcessorList"] = None,
    ):
        super().__init__(
            input_ids=input_ids,
            assistant_model=assistant_model,
            generation_config=generation_config,
            model_kwargs=model_kwargs,
            inputs_tensor=inputs_tensor,
            logits_processor=logits_processor,
        )
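
        # we have to move early exit out of the generation config, otherwise the assistant's own `generate` call
        # would also exit early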
        self.assistant_early_exit = self.generation_config.assistant_early_exit
        self.generation_config.assistant_early_exit = None

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
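        # temporarily sets the number of hidden layers to the early exit value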
        base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
        original_num_hidden_layers = base_model.config.num_hidden_layers
        base_model.config.num_hidden_layers = self.assistant_early_exit
        candidate_ids, candidate_logits = super().get_candidates(input_ids)
        base_model.config.num_hidden_layers = original_num_hidden_layers
        return candidate_ids, candidate_logits


def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool) -> dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""

    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]

    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
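
    # handle models with cross-attention image masks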
    if "cross_attention_mask" in model_kwargs:
        cross_mask = model_kwargs["cross_attention_mask"]
        if mask_length_diff < 0:
            model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
        elif mask_length_diff > 0:
            new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
            model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
    elif "image_attention_mask" in model_kwargs:
        cross_mask = model_kwargs["image_attention_mask"]
        if mask_length_diff < 0:
            model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
        elif mask_length_diff > 0:
            new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
            model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)

    return model_kwargs


def _prepare_token_type_ids(model_kwargs: dict[str, Any], new_length: int) -> dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    type_length_diff = new_length - token_type_ids.shape[1]

    if type_length_diff < 0:
        model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    return model_kwargs