import copy
import os
from decimal import Decimal, ROUND_HALF_UP
from typing import Any, Callable, Dict, Optional, Tuple, Union, TYPE_CHECKING
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel
from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
from transformers.generation.logits_process import (
    LogitsProcessorList,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    WhisperNoSpeechDetection,
)
from transformers.generation.stopping_criteria import (
StoppingCriteriaList,
)
from transformers.generation.utils import (
    GenerateNonBeamOutput,
    GenerateEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateBeamOutput,
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
)
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.whisper.modeling_whisper import (
WhisperForConditionalGeneration,
)
from transformers.utils import logging
from .decoding import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
from .utils import WhisperTimeStampLogitsProcessorCustom
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
logging.set_verbosity_debug()
logger = logging.get_logger("transformers")
class DiCoWGenerationMixin(WhisperForConditionalGeneration):
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
) -> Dict[str, Any]:
        # pylint: disable=no-member
model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name, generation_config
)
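        # Custom: cache the encoder-level CTC logits so the logits processors can later use them
        # for joint CTC/attention rescoring (see `_get_logits_processor`).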
if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
self.encoder_logits = self.get_enc_logits(model_kwargs["encoder_outputs"].last_hidden_state)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: torch.Tensor,
device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
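        # Override the caller-supplied batch size with the actual number of prepared decoder
        # prompts, which can differ from the caller's value during longform decoding.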
batch_size = model_kwargs['decoder_input_ids'].shape[0]
out = super()._prepare_decoder_input_ids_for_generation(
batch_size,
model_input_name,
model_kwargs,
decoder_start_token_id,
device,
)
return out
def prepare_kwargs_for_generate(self,
max_frames,
cur_bsz,
batch_idx_map,
seek,
kwargs,
attention_mask):
"""This method also prepares STNO masks and other kwargs for generation."""
seek_vad = seek // 2
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
num_segment_frames = input_stride * self.config.max_source_positions
num_frames_vad = num_segment_frames // 2
max_frames_vad = max_frames // 2
seek_num_frames = (max_frames_vad - seek_vad).clamp(max=num_frames_vad)
stno_masks = []
for i in range(cur_bsz):
prev_i = batch_idx_map[i]
segment_input_slice = kwargs["stno_mask"][prev_i: prev_i + 1, :,
seek_vad[prev_i]: seek_vad[prev_i] + seek_num_frames[prev_i]]
if segment_input_slice.shape[-1] < num_frames_vad:
orig_len = segment_input_slice.shape[-1]
# pad to 1500 if necessary
segment_input_slice = torch.nn.functional.pad(
segment_input_slice, pad=(0, num_frames_vad - orig_len)
)
# set corresponding padding tokens to 1 in vad mask representing silence
segment_input_slice[0, 0, orig_len:] = 1.0
stno_masks.append(segment_input_slice)
kwargs["stno_mask"] = torch.cat(stno_masks, dim=0)
self.stno_mask_seek = kwargs["stno_mask"]
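        # Keep the window-sliced STNO mask around for forward passes in this seek step; it is
        # cleared again at the end of `generate_with_fallback`.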
if self.config.use_enrollments and "enrollments" in kwargs:
for key in kwargs["enrollments"]:
kwargs["enrollments"][key] = kwargs["enrollments"][key][batch_idx_map]
if attention_mask is not None:
attention_mask = attention_mask[batch_idx_map]
if "labels" in kwargs:
kwargs['labels'] = kwargs["labels"][batch_idx_map]
kwargs['upp_labels'] = kwargs["upp_labels"][batch_idx_map]
return kwargs, attention_mask
def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None)
if "enrollments" in kwargs:
self.enrollments = kwargs["enrollments"]
        forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None)
        if forced_decoder_ids is None and hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
            forced_decoder_ids = config.forced_decoder_ids
        if forced_decoder_ids is not None:
            if language is None and task is None and forced_decoder_ids[0][1] is None:
                logger.warning_once(
                    "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. "
                    "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
                )
            elif language is not None:
                logger.info(
                    f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
                )
                forced_decoder_ids = None
if forced_decoder_ids is not None:
return forced_decoder_ids
init_tokens = super()._retrieve_init_tokens(input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
        if hasattr(self, "enrollments"):
            del self.enrollments
return init_tokens
def detect_language(
self,
input_features: Optional[torch.FloatTensor] = None,
encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
generation_config: Optional[GenerationConfig] = None,
num_segment_frames: int = 3000,
) -> torch.Tensor:
"""
Detects language from log-mel input features or encoder_outputs
Parameters:
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
num_segment_frames (`int`, *optional*, defaults to 3000):
                The number of log-mel frames the model expects.
Return:
A `torch.LongTensor` representing the detected language ids.
"""
if input_features is None and encoder_outputs is None:
raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
elif input_features is not None and encoder_outputs is not None:
raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")
elif input_features is not None:
inputs = {"input_features": input_features[:, :, :num_segment_frames]}
batch_size = input_features.shape[0]
elif encoder_outputs is not None:
inputs = {"encoder_outputs": encoder_outputs}
            # `encoder_outputs[0]` is `last_hidden_state` for both `BaseModelOutput` and plain tuples.
            batch_size = encoder_outputs[0].shape[0]
generation_config = generation_config or self.generation_config
decoder_input_ids = (
torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
* generation_config.decoder_start_token_id
)
        with torch.no_grad():
            # Custom: pass the STNO mask (and enrollments, if available) so that language
            # detection runs on the same speaker-conditioned encoder representation as decoding.
            if hasattr(self, "enrollments"):
                inputs["enrollments"] = self.enrollments
            logits = self(
                **inputs,
                decoder_input_ids=decoder_input_ids,
                use_cache=False,
                stno_mask=self.stno_mask[:, :, : num_segment_frames // 2],
            ).logits[:, -1]
non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
non_lang_mask[list(generation_config.lang_to_id.values())] = False
logits[:, non_lang_mask] = -np.inf
lang_ids = logits.argmax(-1)
return lang_ids
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: Optional[int] = None,
encoder_input_ids: Optional[torch.LongTensor] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
device: Optional[str] = None,
model_kwargs: Optional[dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:
# pylint: disable=no-member
gen_config_copy = copy.deepcopy(generation_config)
gen_config_copy.forced_decoder_ids = None
processors = super()._get_logits_processor(
gen_config_copy,
input_ids_seq_length,
encoder_input_ids,
prefix_allowed_tokens_fn,
logits_processor,
device,
model_kwargs,
negative_prompt_ids,
negative_prompt_attention_mask,
)
if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
enc_logits = self.encoder_logits
if generation_config.num_beams <= 1:
processors.append(LogSoftmaxProcessor())
else:
enc_logits = enc_logits.repeat_interleave(generation_config.num_beams, dim=0)
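            # Custom: rescore the decoder hypotheses with the cached encoder CTC logits
            # (joint CTC/attention decoding); for beam search the logits were tiled per beam above.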
self.ctc_rescorer = CTCRescorerLogitsProcessor(
enc_logits,
torch.full((enc_logits.shape[0],), fill_value=enc_logits.shape[1],
device=enc_logits.device),
enc_logits.shape[-1] - 1,
generation_config.pad_token_id,
generation_config.eos_token_id,
generation_config.decoder_start_token_id,
self.tokenizer,
0,
generation_config.ctc_weight,
generation_config.num_beams,
False,
)
processors.append(self.ctc_rescorer)
return processors
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
        if generation_config.return_timestamps is True:
            # Custom: use WhisperTimeStampLogitsProcessorCustom instead of the stock timestamp processor.
            timestamp_processor = WhisperTimeStampLogitsProcessorCustom(generation_config, begin_index=begin_index)
            logits_processor = (
                [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
            )
if generation_config.suppress_tokens is not None:
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
logits_processor = (
[suppress_tokens_processor]
if logits_processor is None
else [suppress_tokens_processor] + logits_processor
)
generation_config.suppress_tokens = None
if generation_config.begin_suppress_tokens is not None:
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
)
logits_processor = (
[begin_suppress_processor]
if logits_processor is None
else [begin_suppress_processor] + logits_processor
)
generation_config.begin_suppress_tokens = None
if generation_config.no_speech_threshold is not None:
no_speech_detector = WhisperNoSpeechDetection(
no_speech_token=generation_config.no_timestamps_token_id - 1,
begin_index=begin_index,
scores_is_logprobs=num_beams > 1,
)
logits_processor = (
[no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
)
no_speech_detector.set_model(self)
return logits_processor
@staticmethod
def round_to_nearest_0_02(x):
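        """Round `x` to the 0.02 s timestamp grid, e.g. 29.987 -> Decimal('29.98') and 29.99 -> Decimal('30.00')."""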
d = Decimal(str(x)) # Use str(x) to preserve input precision
step = Decimal('0.02')
# Divide, round, multiply back
rounded = (d / step).to_integral_value(rounding=ROUND_HALF_UP) * step
return rounded
def _fix_timestamps_from_segmentation(self, sequences):
"""
Adjusts token sequences with global timestamps to fit within Whisper's 0–30s timestamp token range.
"""
# Get the token ID for the "<|0.00|>" timestamp used to detect dummy segments
first_timestamp_token = self.tokenizer.get_vocab()["<|0.00|>"]
empty_text_token = self.tokenizer.get_vocab()["Ġ"]
results = []
# Filter out segments that are either empty or consist only of the "<|0.00|>" token
for idx, sequence_segs in enumerate(sequences['segments']):
sequences['segments'][idx] = [
seg for seg in sequence_segs
if len(seg['tokens']) > 0 and (len(seg['tokens']) != 1 or seg['tokens'][0] != first_timestamp_token)
]
# Iterate over each group of segments
for idx, sequence_segs in enumerate(sequences['segments']):
result = []
prev_segment_end_time = None
correction = Decimal(0.0)
for i, seg in enumerate(sequence_segs):
# Round start and end times to nearest 0.02 seconds
start_time = self.round_to_nearest_0_02(seg['start'].item())
end_time = self.round_to_nearest_0_02(seg['end'].item())
tokens = seg['tokens']
# Determine which 30s window this segment falls into
current_block = (start_time + correction) // 30
if prev_segment_end_time is not None:
# We subtract a tiny epsilon from prev_segment_end_time.
# If prev ended exactly at 30.0, it belongs to block 0, not block 1.
# 30.0 // 30 = 1 (Wrong) | 29.999 // 30 = 0 (Correct)
prev_block = (prev_segment_end_time - Decimal("0.001")) // 30
num_dummies = current_block - prev_block - 1
                    # Close the previous block with a (30, [space], 30) marker if we're moving to a new block
if current_block > prev_block:
result.append((30, [empty_text_token], 30))
# Insert dummy segments to bridge skipped 30s blocks
for _ in range(int(num_dummies)):
result.append((0, [empty_text_token], 30))
else:
# For the first segment, add dummy blocks if it starts after 30s
for _ in range(int(start_time // 30)):
result.append((0, [empty_text_token], 30))
# Determine whether segment fits in one block or wraps to the next
if ((start_time + correction) // 30 == (end_time + correction) // 30):
# Segment fits within a single 30s window
result.append(((start_time + correction) % 30, tokens, (end_time + correction) % 30))
elif (end_time + correction) % 30 == 0:
result.append(((start_time + correction) % 30, tokens, 30))
# Important: reset correction if we landed exactly on the boundary
correction = Decimal(0.0)
else:
# Segment would wrap across a 30s boundary
new_seg_start = (correction + start_time) % 30
seg_duration = end_time - start_time
new_end_time = (end_time + correction) % 30
if seg_duration == 30.0:
if float(new_seg_start) % 30.0 == 0.0:
new_end_time = Decimal(30.0)
correction = Decimal(0.0)
else:
                            correction = Decimal("-0.02")
                            new_end_time += correction
else:
correction = Decimal(0.0)
result.append((new_seg_start, tokens, new_end_time))
# Update the previous segment's end time for next iteration
prev_segment_end_time = end_time + correction
# Convert result segments into a token sequence with proper timestamp formatting
encoded = self.tokenizer(
"".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result])
)['input_ids']
results.append(encoded)
# Pad all sequences to the same length for batching
sequences = pad_sequence(
[torch.tensor(res, device=sequences['sequences'].device) for res in results],
batch_first=True,
padding_value=self.tokenizer.pad_token_id
)
return sequences
@staticmethod
def _retrieve_segment(
seek_sequence,
seek_outputs,
time_offset,
timestamp_begin,
seek_num_frames,
time_precision,
time_precision_features,
input_stride,
prev_idx,
idx,
return_token_timestamps,
decoder_input_ids,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
idx_offset = decoder_input_ids.shape[-1]
device = seek_sequence.device
        # If Whisper predicted an "end of segment" via a timestamp token, go over each
        # "end of segment" prediction and slice the decoding into segments accordingly
if len(timestamp_segment_indices) > 0:
# if the output contains two consecutive timestamp tokens
slices = timestamp_segment_indices.tolist()
segments = []
if single_timestamp_ending:
slices.append(len(seek_sequence))
else:
                # include the last timestamp token in the last segment so we know it was not a single-timestamp ending
slices[-1] += 1
last_slice = 0
# Add each segment to list of all segments
for i, current_slice in enumerate(slices):
is_last_slice = i == len(slices) - 1
sliced_tokens = seek_sequence[last_slice:current_slice]
start_timestamp_pos = sliced_tokens[0] - timestamp_begin
idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx]
+ start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
* time_precision,
"end": time_offset[prev_idx]
+ end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
* time_precision,
"tokens": sliced_tokens,
"idxs": (idx_offset + last_slice, idx_offset + current_slice),
"result": seek_outputs[idx],
}
)
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
token_timestamps[idx_offset + last_slice: idx_offset + current_slice] + time_offset[
prev_idx]
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
segment_offset = seek_num_frames[prev_idx]
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
# here we throw away all predictions after the last predicted "end of segment"
# since we are cutting right in the middle of an audio
last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
segment_offset = last_timestamp_pos * input_stride
else:
# If whisper does not predict any "end of segment" token, then
# the whole decoding is considered a segment and we add it to the list of segments
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
start_timestamp_pos = 0.0
last_timestamp_pos = seek_num_frames[prev_idx] // 2
skip = False
segment_offset = seek_num_frames[prev_idx]
if timestamps.numel() > 1:
start_timestamp_pos = timestamps[-2].item() - timestamp_begin
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
elif timestamps.numel() == 1:
# no consecutive timestamps but it has a timestamp; use the last one.
start_timestamp_pos = timestamps[-1].item() - timestamp_begin
if start_timestamp_pos > 200:
                    # segment does not fit into the decoding window, so we need to roll back
segment_offset = start_timestamp_pos * input_stride - 100 # timestamp might be inaccurate
skip = True
elif timestamps.numel() == 0 and len(seek_sequence) > 1:
# Decoding without timestamps, return output as it is
pass
else:
                # empty or single-token sequence without timestamps; skip it
skip = True
if skip:
segments = []
else:
segments = [
{
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
"tokens": seek_sequence,
"result": seek_outputs[idx],
}
]
if return_token_timestamps:
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segment_offset = seek_num_frames[prev_idx]
if segment_offset <= 0:
msg = f"Timestamps: {timestamps}, Segments: {segments}"
raise ValueError(f"Segment offset: {segment_offset} <= 0. This should not happen!\n{msg}")
return segments, segment_offset
def generate(
self,
generation_config: Optional[GenerationConfig] = None,
condition_on_prev_tokens: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
**kwargs,
):
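        """
        Wrapper around `WhisperForConditionalGeneration.generate` that validates the decoding mode,
        stashes the STNO mask for the current batch, and post-processes the returned segments into a
        single timestamped token sequence. A minimal call might look like the following sketch
        (kwarg names other than `stno_mask` follow the usual Whisper `generate` interface; the exact
        preprocessing depends on the surrounding pipeline):

            output = model.generate(
                input_features=input_features,    # log-mel features from the feature extractor
                attention_mask=attention_mask,    # per-sample frame mask used for longform seeking
                stno_mask=stno_mask,              # speaker STNO mask at the encoder frame rate
            )
        """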
if condition_on_prev_tokens:
            raise NotImplementedError("The current version does not support conditioning on previous tokens.")
gen_c, _ = self._prepare_generation_config(generation_config, **kwargs)
gen_mode = gen_c.get_generation_mode(assistant_model)
if gen_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.BEAM_SEARCH]:
raise ValueError(
f"Provided generation mode {gen_mode} is not supported"
f" for WhisperForConditionalGeneration with joint CTC decoding")
if "stno_mask" in kwargs:
self.stno_mask = kwargs["stno_mask"]
        output = super().generate(generation_config=generation_config, **kwargs, return_segments=True)
self.encoder_logits = None
if isinstance(output, dict):
output = self._fix_timestamps_from_segmentation(output)
return output
def generate_with_fallback(
self,
segment_input,
decoder_input_ids,
cur_bsz,
seek,
batch_idx_map,
temperatures,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_token_timestamps,
do_condition_on_prev_tokens,
is_shortform,
batch_size,
attention_mask,
kwargs,
):
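        # Custom: slice the per-sample kwargs (STNO mask, enrollments, labels, attention mask) to
        # the current seek window before running the standard temperature-fallback loop.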
kwargs_local = copy.deepcopy(kwargs)
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
kwargs_local, attention_mask = self.prepare_kwargs_for_generate(max_frames, cur_bsz, batch_idx_map, seek, kwargs_local, attention_mask)
seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type = super().generate_with_fallback(
segment_input,
decoder_input_ids,
cur_bsz,
seek,
batch_idx_map,
temperatures,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_token_timestamps,
do_condition_on_prev_tokens,
is_shortform,
batch_size,
attention_mask,
kwargs_local,
)
self.stno_mask_seek = None
return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2":
# only raise warning if the user passed an explicit compile-config
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False
model_forward = self.get_compiled_call(generation_config.compile_config)
if generation_config.prefill_chunk_size is not None:
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
is_prefill = False
else:
is_prefill = True
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
if is_prefill:
outputs = self(**model_inputs, return_dict=True)
is_prefill = False
else:
outputs = model_forward(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# token selection
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
""""""
# Based on the next tokens select the ctc prev states and scores
if hasattr(self, "ctc_rescorer"):
self.ctc_rescorer.update_state(next_tokens, torch.arange(next_tokens.shape[0]))
""""""
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del outputs
if streamer is not None:
streamer.end()
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
def _beam_search(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
https://huggingface.co/blog/how-to-generate (especially the beam search section).
You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
(https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
            [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# 1. init beam_search values
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
do_sample = generation_config.do_sample
early_stopping = generation_config.early_stopping
length_penalty = generation_config.length_penalty
max_length = generation_config.max_length
num_beams = generation_config.num_beams
num_return_sequences = generation_config.num_return_sequences
batch_size_unflattened, cur_len = input_ids.shape[:2]
batch_size = batch_size_unflattened // num_beams
# TODO (joao): standardize special cases
if self.__class__.__name__ == "MoshiDepthDecoder":
vocab_size = self.config.audio_vocab_size
elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
vocab_size = self.get_output_embeddings().out_features
else:
vocab_size = self.config.get_text_config().vocab_size
decoder_prompt_len = cur_len
this_peer_finished = False
# At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
# with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
# (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
# non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
top_num_beam_mask = torch.cat(
(torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
dim=0,
).to(input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
# are newer low-memory alternatives like the offloaded cache)
sequential = generation_config.low_memory
if sequential:
raise ValueError(
"`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
"#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
)
# 2. init output tuples
all_scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
beam_indices = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# 3. init running tensors and static-shaped placeholders
# per batch, beam-item holding current token in loop and completed sequences
output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
running_sequences = torch.full(
(batch_size, num_beams, max_length),
fill_value=output_fill_value,
dtype=torch.int64,
device=input_ids.device,
)
running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
sequences = running_sequences.detach().clone()
# per batch, beam-item score, logprobs
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
running_beam_scores[:, 1:] = -1e9
beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)
# per batch, beam-item state bit indicating if sentence has finished.
is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
# per batch state bit indicating if there is a possibility to improve the best finished sentence.
is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
# per batch, beam-item state bit indicating if there are valid continuations.
next_token_hits_stopping_criteria = torch.zeros(
(batch_size, num_beams), dtype=torch.bool, device=input_ids.device
)
# per batch selected beam indices
running_beam_indices = torch.full(
(batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
)
beam_indices = running_beam_indices.detach().clone()
# 4. run the generation loop
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# a. Forward current tokens, obtain the logits
flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
model_outputs = self(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
model_outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue
# Copy is needed to avoid keeping a hanging ref
logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
# b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
# `temperature`, ...), and add new logprobs to existing running logprobs scores.
log_probs = nn.functional.log_softmax(logits, dim=-1)
log_probs = logits_processor(flat_running_sequences, log_probs)
# Store logits, attentions and hidden_states when required
if return_dict_in_generate:
if output_logits:
raw_logits += (logits.clone(),)
if return_dict_in_generate and output_scores:
all_scores += (log_probs.clone(),)
if output_attentions:
decoder_attentions += (
(model_outputs.decoder_attentions,)
if self.config.is_encoder_decoder
else (model_outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (model_outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(model_outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (model_outputs.hidden_states,)
)
# This is needed to properly delete logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del model_outputs
log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
log_probs = log_probs + running_beam_scores[:, :, None]
log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))
# c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
# continuations among all beams based on the accumulated scores.
topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
accumulated_log_probs=log_probs,
running_sequences=running_sequences,
running_beam_indices=running_beam_indices,
cur_len=cur_len,
decoder_prompt_len=decoder_prompt_len,
do_sample=do_sample,
beams_to_keep=beams_to_keep,
num_beams=num_beams,
vocab_size=vocab_size,
batch_size=batch_size,
)
# d. Check which running sequences have finished
next_token_hits_stopping_criteria = stopping_criteria(
self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
all_scores,
)
next_token_hits_stopping_criteria = self._unflatten_beam_dim(
next_token_hits_stopping_criteria, batch_size, beams_to_keep
)
# e. Get the non-finished running `num_beams` sequences for the next generation step
running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
topk_log_probs=topk_log_probs,
topk_running_sequences=topk_running_sequences,
topk_running_beam_indices=topk_running_beam_indices,
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
num_beams=num_beams,
)
# f. Update the completed beams if a new high score in a finished sequence is found
sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
sequences=sequences,
topk_running_sequences=topk_running_sequences,
beam_scores=beam_scores,
topk_log_probs=topk_log_probs,
beam_indices=beam_indices,
topk_running_beam_indices=topk_running_beam_indices,
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
is_sent_finished=is_sent_finished,
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
top_num_beam_mask=top_num_beam_mask,
num_beams=num_beams,
cur_len=cur_len,
decoder_prompt_len=decoder_prompt_len,
length_penalty=length_penalty,
early_stopping=early_stopping,
)
# g. Prepare remaining data for the next iteration, including computing the stopping condition for
# beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
beam_idx = None
# pluck the cache from the beam indices that will be used in the next iteration
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
if model_kwargs.get("past_key_values", None) is not None:
beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
if hasattr(self, "_reorder_cache"):
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
else:
model_kwargs["past_key_values"].reorder_cache(beam_idx)
if hasattr(self, "ctc_rescorer"):
self.ctc_rescorer.update_state(running_sequences.flatten(0,1)[:, cur_len], beam_idx)
cur_len = cur_len + 1
is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
running_beam_scores=running_beam_scores,
beam_scores=beam_scores,
is_sent_finished=is_sent_finished,
cur_len=cur_len,
max_length=max_length,
decoder_prompt_len=decoder_prompt_len,
early_stopping=early_stopping,
length_penalty=length_penalty,
)
this_peer_finished = not self._beam_search_has_unfinished_sequences(
is_early_stop_heuristic_unsatisfied,
is_sent_finished,
next_token_hits_stopping_criteria,
early_stopping,
)
# 5. prepare outputs
# Take best beams for each batch (the score is sorted in descending order)
sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
# Crop the static-shaped tensors to the actual size.
# `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
# step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
# previous decoding iteration)
max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
output_length = decoder_prompt_len + max_generated_length
sequences = sequences[:, :output_length]
beam_indices = beam_indices[:, :max_generated_length]
if return_dict_in_generate:
if not output_scores:
beam_scores = None
if self.config.is_encoder_decoder:
return GenerateBeamEncoderDecoderOutput(
sequences=sequences,
sequences_scores=beam_scores,
scores=all_scores,
logits=raw_logits,
beam_indices=beam_indices,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GenerateBeamDecoderOnlyOutput(
sequences=sequences,
sequences_scores=beam_scores,
scores=all_scores,
logits=raw_logits,
beam_indices=beam_indices,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequences