| """ |
| This is a self-contained and flexible beam search implementation adapted from |
| AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py |
| """ |
|
|
| import copy |
| import warnings |
| from abc import abstractmethod |
| from inspect import signature |
| from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast |
|
|
| import torch |
|
|
| __all__ = [ |
| "Sampler", |
| "DeterministicSampler", |
| "MultinomialSampler", |
| "TopKSampler", |
| "TopPSampler", |
| "GumbelSampler", |
| "FinalSequenceScorer", |
| "SequenceLogProbabilityScorer", |
| "LengthNormalizedSequenceLogProbabilityScorer", |
| "Constraint", |
| "RepeatedNGramBlockingConstraint", |
| "BeamSearch", |
| ] |
|
|
| StateType = Dict[str, torch.Tensor] |
| StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]] |
| StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]] |
|
|
| StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep) |
| """ |
| The type of step function that can be passed to [`BeamSearch.search`](#search). |
| |
| This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep) |
| or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep). |
| """ |
|
|
| ConstraintStateType = List[List[Dict[str, Any]]] |
|
|
|
|
| class Sampler: |
| """ |
| An abstract class that can be used to sample candidates (either nodes or beams) |
| within `BeamSearch`. |
| |
| A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`. |
| |
| `init_state()` takes three arguments: |
| |
| - a tensor of starting log probs with shape `(batch_size,, num_classes)`, |
| - the batch size, an int, |
| - and the number of classes, also an int. |
| |
| It returns a state dictionary with any state tensors needed for subsequent |
| calls to `sample_nodes()` and `sample_beams()`. |
| |
| By default this method just returns an empty dictionary. |
| |
| Both `sample_nodes()` and `sample_beams()` should take three arguments: |
| |
| - tensor of normalized log probabilities with shape `(batch_size, num_examples)`, |
| - an integer representing the number of samples to take for each example in the batch, |
| - and a state dictionary which could contain any tensors needed for the `Sampler` to keep |
| track of state. |
| |
| For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`, |
| `num_examples = beam_size * per_node_beam_size`. |
| |
| The return value should be a tuple containing: |
| |
| - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`, |
| - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`, |
| - and the updated state dictionary. |
| |
| A default implementation of `sample_beams` is provided, which just deterministically |
| picks the `k` examples with highest log probability. |
| """ |
|
|
| def init_state( |
| self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int |
| ) -> StateType: |
| del start_class_log_probabilities, batch_size, num_classes |
| return {} |
|
|
| @abstractmethod |
| def sample_nodes( |
| self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| raise NotImplementedError |
|
|
| def sample_beams( |
| self, log_probs: torch.Tensor, beam_size: int, state: StateType |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| del state |
| selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1) |
| return selected_log_probs, selected_indices, {} |
|
|
|
|
| class DeterministicSampler(Sampler): |
| """ |
| A `Sampler` that just deterministically returns the `k` nodes or beams with highest |
| log probability. |
| """ |
|
|
| def sample_nodes( |
| self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| del state |
| selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1) |
| return selected_log_probs, selected_indices, {} |
|
|
|
|
| class MultinomialSampler(Sampler): |
| """ |
| A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled |
| in the default, non-deterministic way. |
| |
| :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` |
| above 1.0 produces a flatter probability distribution. |
| :param with_replacement: Whether to sample with replacement. |
| |
| """ |
|
|
| def __init__( |
| self, |
| temperature: float = 1.0, |
| with_replacement: bool = False, |
| ) -> None: |
| self.temperature = temperature |
| self.with_replacement = with_replacement |
|
|
| def sample_nodes( |
| self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| if self.temperature != 1.0: |
| _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1) |
| else: |
| _probabilities = log_probs.exp() |
|
|
| selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement) |
|
|
| return torch.gather(log_probs, 1, selected_indices), selected_indices, state |
|
|
|
|
| class TopKSampler(Sampler): |
| """ |
| A `Sampler` which redistributes the probability mass function for nodes among the |
| top `k` choices, then samples from that subset after re-normalizing the probabilities. |
| |
| Beams are sampled in the default, deterministic way. |
| |
| :param k: The number of top choices to be selected from. |
| :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` |
| above 1.0 produces a flatter probability distribution. |
| :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices. |
| """ |
|
|
| def __init__( |
| self, |
| k: int = 1, |
| temperature: float = 1.0, |
| with_replacement: bool = False, |
| ): |
| self.k = k |
| self.temperature = temperature or 1.0 |
| self.with_replacement = with_replacement |
|
|
| def sample_nodes( |
| self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| if not per_node_beam_size <= self.k <= log_probs.size()[1]: |
| raise ValueError( |
| "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size" |
| ) |
|
|
| |
| top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1) |
|
|
| |
| |
| if self.temperature != 1.0: |
| top_k_log_probs = top_k_log_probs / self.temperature |
|
|
| |
| |
| normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1) |
|
|
| |
| |
| |
| sampled_indices = torch.multinomial( |
| normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement |
| ) |
|
|
| |
| |
| indices = top_k_indices.gather(-1, sampled_indices) |
|
|
| return log_probs.gather(1, indices), indices, state |
|
|
|
|
| class TopPSampler(Sampler): |
| """ |
| A `Sampler` which redistributes the probability mass function for nodes among |
| the top choices with a cumulative probability of at least `p`, then samples from that subset |
| after re-normalizing the probabilities. |
| |
| Beams are sampled in the default, deterministic way. |
| |
| :param p: |
| The cumulative probability cutoff threshold. A higher value of `p` will result in more possible |
| examples to sample from. If `with_replacement` is `False` and the number of possible samples is |
| insufficient to sample without replacement from when calling `sample_nodes`, then the top |
| `per_node_beam_size` examples will be chosen. |
| :param temperature: |
| A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` |
| above 1.0 produces a flatter probability distribution. |
| :param with_replacement: |
| If set to `True`, samples will be selected with replacement from the top choices. |
| |
| """ |
|
|
| def __init__( |
| self, |
| p: float = 0.9, |
| temperature: float = 1.0, |
| with_replacement: bool = False, |
| ): |
| if p < 0.0 or p > 1.0: |
| raise ValueError("p must be a positive float no greater than 1.0") |
| self.p = p |
| self.temperature = temperature or 1.0 |
| self.with_replacement = with_replacement |
|
|
| def sample_nodes( |
| self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| if not per_node_beam_size <= log_probs.size()[1]: |
| raise ValueError("per_node_beam_size cannot be greater than vocabulary size") |
|
|
| |
| if self.temperature != 1.0: |
| _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1) |
| else: |
| _log_probs = log_probs |
|
|
| |
| log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True) |
|
|
| |
| probabilities_descending = log_probs_descending.exp() |
| probabilities_summed = torch.cumsum(probabilities_descending, dim=-1) |
|
|
| |
| |
| exclusion_mask = probabilities_summed >= self.p |
|
|
| |
| exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone() |
| exclusion_mask[..., 0] = False |
|
|
| |
| if not self.with_replacement: |
| exclusion_mask[..., :per_node_beam_size] = False |
|
|
| log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min |
|
|
| |
| |
| filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1) |
|
|
| |
| |
| |
| sampled_indices = torch.multinomial( |
| filtered_probabilities, per_node_beam_size, replacement=self.with_replacement |
| ) |
|
|
| |
| |
| selected_indices = sorting_indices.gather(-1, sampled_indices) |
|
|
| |
| |
| return torch.gather(log_probs, 1, selected_indices), selected_indices, state |
|
|
|
|
| class GumbelSampler(Sampler): |
| """ |
| A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See |
| [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling |
| Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010] |
| (https://api.semanticscholar.org/CorpusID:76662039). |
| |
| :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` |
| above 1.0 produces a flatter probability distribution. |
| """ |
|
|
| def __init__(self, temperature: float = 1.0): |
| self.temperature = temperature |
|
|
| def init_state( |
| self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int |
| ) -> StateType: |
| |
| zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes)) |
|
|
| |
| G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros) |
|
|
| return {"G_phi_S": G_phi_S} |
|
|
| def sample_nodes( |
| self, |
| log_probs: torch.Tensor, |
| per_node_beam_size: int, |
| state: StateType, |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| |
| |
| if self.temperature != 1.0: |
| _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1) |
| else: |
| _log_probs = log_probs |
|
|
| |
| phi_S = state["phi_S"] |
|
|
| |
| phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs) |
|
|
| |
| phi_S_new = phi_S + _log_probs |
|
|
| |
| G_phi_S = state["G_phi_S"].unsqueeze(-1) |
|
|
| |
| G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S) |
|
|
| |
| |
| |
|
|
| |
| top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1) |
|
|
| |
| top_log_probs = log_probs.gather(1, top_indices) |
|
|
| return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new} |
|
|
| def sample_beams( |
| self, |
| log_probs: torch.Tensor, |
| beam_size: int, |
| state: StateType, |
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: |
| """ |
| Returns the beams with the highest perturbed log probabilities. |
| """ |
| |
|
|
| batch_size = log_probs.size()[0] |
|
|
| |
| G_phi_S = state["G_phi_S"] |
|
|
| |
| G_phi_S = G_phi_S.reshape_as(log_probs) |
|
|
| |
| G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1) |
|
|
| |
| selected_log_probs = log_probs.gather(1, selected_indices) |
|
|
| |
| |
| selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True) |
| selected_indices = selected_indices.gather(1, sort_indices) |
| G_phi_S_new = G_phi_S_new.gather(1, sort_indices) |
|
|
| |
| G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size) |
|
|
| |
| phi_S = selected_log_probs.reshape(batch_size * beam_size) |
|
|
| return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S} |
|
|
| def gumbel(self, phi) -> torch.Tensor: |
| """ |
| Sample `Gumbel(phi)`. |
| |
| `phi` should have shape `(batch_size, num_classes)`. |
| """ |
| return -torch.log(-torch.log(torch.rand_like(phi))) + phi |
|
|
| def gumbel_with_max(self, phi, T) -> torch.Tensor: |
| """ |
| Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`. |
| |
| `phi` should have shape `(batch_size, num_classes)` and `T` should have |
| shape `(batch_size, 1)`. |
| """ |
| |
| G_phi = self.gumbel(phi) |
|
|
| |
| |
| Z, _ = G_phi.max(dim=-1) |
|
|
| |
| v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1))) |
|
|
| |
| return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs())) |
|
|
|
|
| class FinalSequenceScorer: |
| """ |
| An abstract class that can be used to score the final generated sequences found |
| by beam search. Given the predicted sequences and the corresponding log probabilities of |
| those sequences, the class calculates and returns the final score of the sequences. |
| |
| The default implementation scores the sequences using the sum of the log probabilities of |
| the sequence, which is passed as input. |
| """ |
|
|
| @abstractmethod |
| def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: |
| """ |
| Score the final predictions found by beam search. |
| Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`. |
| |
| :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`. |
| :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum |
| of the log probabilities per token, with shape `(batch_size, beam_size)`. |
| :param end_index: The index of the end symbol. |
| |
| """ |
| raise NotImplementedError |
|
|
|
|
| class SequenceLogProbabilityScorer(FinalSequenceScorer): |
| """ |
| A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities |
| across the sequence's tokens. |
| """ |
|
|
| def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: |
| del predictions, end_index |
| |
| |
| return log_probabilities |
|
|
|
|
| class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer): |
| """ |
| A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the |
| tokens in the sequence. It optionally includes a length penalty which promotes |
| or demotes sequences based on their lengths. The final score for a sequence will |
| be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length |
| here includes the end token. |
| |
| :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used. |
| A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences. |
| """ |
|
|
| def __init__(self, length_penalty: float = 1.0): |
| super().__init__() |
| self.length_penalty = length_penalty |
|
|
| def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: |
| |
| lengths = (predictions != end_index).long().sum(dim=2) |
|
|
| |
| |
| |
| |
| is_end_token = predictions[:, :, -1] == end_index |
| lengths += is_end_token.long() |
|
|
| |
| average_log_probs = log_probabilities / (lengths**self.length_penalty) |
| return average_log_probs |
|
|
|
|
| class Constraint: |
| """ |
| An abstract class that can be used to enforce constraints on the output predictions |
| by manipulating the class log probabilities during beam search. |
| |
| A `Constraint` just has three methods that need to be implemented by subclasses: |
| `init_state()`, `apply()` and `_update_state()`. |
| |
| `init_state()` takes one argument: |
| |
| - the batch size, an int |
| |
| It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent |
| calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`. |
| Each inner list should be of length 1. |
| |
| `apply()` takes two arguments: |
| |
| - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size` |
| and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1. |
| - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the |
| log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`. |
| |
| The `apply()` method should return new `class_log_probabilities` that enforce the constraint |
| for this step of beam search. For instance, it may prevent a specific class from being selected by setting |
| the corresponding log probability to a negligible value such as `float("-inf")` or |
| `torch.finfo(class_log_probabilities.dtype).min`. |
| |
| `_update_state()` takes two arguments: |
| |
| - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the |
| copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be |
| directly edited in-place without affecting the others. |
| - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last |
| step of beam search. |
| |
| The `_update_state()` function should return a new constraint state, a nested list of dictionaries of |
| length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`. |
| |
| """ |
|
|
| @abstractmethod |
| def init_state( |
| self, |
| batch_size: int, |
| ) -> ConstraintStateType: |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def apply( |
| self, |
| state: ConstraintStateType, |
| class_log_probabilities: torch.Tensor, |
| ) -> torch.Tensor: |
| raise NotImplementedError |
|
|
| @staticmethod |
| def _copy_state( |
| state: ConstraintStateType, |
| batch_size: int, |
| beam_size: int, |
| last_backpointer: Optional[torch.Tensor] = None, |
| ) -> ConstraintStateType: |
| """ |
| Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this |
| is not appropriate for your constraint, you will need to implement the copying yourself. |
| """ |
| new_state = [] |
| for i in range(batch_size): |
| batch_state = [] |
| for j in range(beam_size): |
| if last_backpointer is None: |
| |
| backpointer = 0 |
| else: |
| backpointer = last_backpointer[i, j].item() |
| batch_state.append(copy.deepcopy(state[i][backpointer])) |
| new_state.append(batch_state) |
| return new_state |
|
|
| def update_state( |
| self, |
| state: ConstraintStateType, |
| last_prediction: torch.Tensor, |
| last_backpointer: Optional[torch.Tensor] = None, |
| ) -> ConstraintStateType: |
| batch_size, beam_size = last_prediction.size() |
| new_state = self._copy_state(state, batch_size, beam_size, last_backpointer) |
| return self._update_state(new_state, last_prediction) |
|
|
| @abstractmethod |
| def _update_state( |
| self, |
| state: ConstraintStateType, |
| last_prediction: torch.Tensor, |
| ) -> ConstraintStateType: |
| raise NotImplementedError |
|
|
|
|
| class RepeatedNGramBlockingConstraint(Constraint): |
| def __init__(self, ngram_size: int, **kwargs) -> None: |
| super().__init__(**kwargs) |
| self.ngram_size = ngram_size |
|
|
| def init_state( |
| self, |
| batch_size: int, |
| ) -> ConstraintStateType: |
| return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)] |
|
|
| def apply( |
| self, |
| state: ConstraintStateType, |
| class_log_probabilities: torch.Tensor, |
| ) -> torch.Tensor: |
| for i, batch in enumerate(state): |
| for j, beam in enumerate(batch): |
| current_prefix = tuple(beam["current_prefix"]) |
| seen_ngrams = beam["seen_ngrams"] |
| try: |
| disallowed_indices = seen_ngrams[current_prefix] |
| class_log_probabilities[i, j, disallowed_indices] = torch.finfo( |
| class_log_probabilities.dtype |
| ).min |
| except KeyError: |
| |
| |
| pass |
| return class_log_probabilities |
|
|
| def _update_state( |
| self, |
| state: ConstraintStateType, |
| last_prediction: torch.Tensor, |
| ) -> ConstraintStateType: |
| for i, batch in enumerate(state): |
| for j, beam in enumerate(batch): |
| prediction = last_prediction[i, j].item() |
| prefix = beam["current_prefix"] |
| seen_ngrams = beam["seen_ngrams"] |
|
|
| if len(prefix) == self.ngram_size - 1: |
| |
| if tuple(prefix) not in seen_ngrams: |
| seen_ngrams[tuple(prefix)] = [] |
| seen_ngrams[tuple(prefix)].append(prediction) |
|
|
| |
| |
| prefix.append(prediction) |
| if len(prefix) == self.ngram_size: |
| prefix.pop(0) |
| return state |
|
|
|
|
| class BeamSearch: |
| """ |
| Implements the beam search algorithm for decoding the most likely sequences. |
| |
| :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID. |
| |
| :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length |
| of the predicted sequences. |
| |
| :param beam_size: The width of the beam used. |
| |
| :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search. |
| If not given, this just defaults to `beam_size`. Setting this parameter |
| to a number smaller than `beam_size` may give better results, as it can introduce |
| more diversity into the search. See |
| [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017] |
| (https://api.semanticscholar.org/CorpusID:2229477). |
| |
| :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams. |
| If not specified, `DeterministicSampler` will be used, which just takes the |
| `per_node_beam_size` most likely nodes and the `beam_size` most likely beams. |
| |
| Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you |
| [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039). |
| |
| :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of |
| the predicted sequences. This does not include the start or end tokens. If `None`, |
| no minimum is enforced. |
| |
| :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences. |
| The output from this module is what is returned by the `search` method. If not |
| specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences |
| by the sum of the token log probabilities. |
| |
| :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not |
| provided, no constraints will be enforced. |
| |
| """ |
|
|
| def __init__( |
| self, |
| end_index: int, |
| *, |
| max_steps: int = 50, |
| beam_size: int = 10, |
| per_node_beam_size: Optional[int] = None, |
| sampler: Optional[Sampler] = None, |
| min_steps: Optional[int] = None, |
| final_sequence_scorer: Optional[FinalSequenceScorer] = None, |
| constraints: Optional[List[Constraint]] = None, |
| ) -> None: |
| if not max_steps > 0: |
| raise ValueError("max_steps must be positive") |
| if not beam_size > 0: |
| raise ValueError("beam_size must be positive") |
| if per_node_beam_size is not None and not per_node_beam_size > 0: |
| raise ValueError("per_node_beam_size must be positive") |
| if min_steps is not None: |
| if not min_steps >= 0: |
| raise ValueError("min_steps must be non-negative") |
| if not min_steps <= max_steps: |
| raise ValueError("min_steps must be less than or equal to max_steps") |
|
|
| self._end_index = end_index |
| self.max_steps = max_steps |
| self.beam_size = beam_size |
| self.per_node_beam_size = per_node_beam_size or beam_size |
| self.sampler = sampler or DeterministicSampler() |
| self.min_steps = min_steps or 0 |
| self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer() |
| self.constraints = constraints or [] |
|
|
| @staticmethod |
| def _reconstruct_sequences(predictions, backpointers): |
| |
| |
| reconstructed_predictions = [predictions[-1].unsqueeze(2)] |
|
|
| if not backpointers: |
| return reconstructed_predictions |
|
|
| |
| cur_backpointers = backpointers[-1] |
|
|
| for timestep in range(len(predictions) - 2, 0, -1): |
| |
| cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2) |
|
|
| reconstructed_predictions.append(cur_preds) |
|
|
| |
| cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers) |
|
|
| |
| final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2) |
|
|
| reconstructed_predictions.append(final_preds) |
|
|
| return reconstructed_predictions |
|
|
| def search( |
| self, |
| start_predictions: torch.Tensor, |
| start_state: StateType, |
| step: StepFunctionType, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Given a starting state and a step function, apply beam search to find the |
| most likely target sequences. |
| |
| Returns a tuple of `(predictions, final_scores)`, where `predictions` |
| has shape `(batch_size, beam_size, max_steps)` and `final_scores` |
| has shape `(batch_size, beam_size)`. |
| |
| .. note:: |
| If your step function returns `-inf` for some log probabilities |
| (like if you're using a masked log-softmax) then some of the "best" |
| sequences returned may also have `-inf` log probability. Specifically |
| this happens when the beam size is smaller than the number of actions |
| with finite log probability (non-zero probability) returned by the step function. |
| Therefore if you're using a mask you may want to check the results from `search` |
| and potentially discard sequences with non-finite log probability. |
| |
| :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`. |
| Usually the initial predictions are just the index of the "start" token |
| in the target vocabulary. |
| |
| :param start_state: The initial state passed to the `step` function. Each value of the state dict |
| should be a tensor of shape `(batch_size, *)`, where `*` means any other |
| number of dimensions. |
| |
| :param step: A function that is responsible for computing the next most likely tokens, |
| given the current state and the predictions from the last time step. |
| The function should accept two or three arguments: |
| |
| - a tensor of shape `(group_size,)` or representing the index of the predicted |
| tokens from the last time step, |
| - the current state, a `StateType`, and |
| - optionally, the timestep, an `int`. |
| |
| The `group_size` will be `batch_size * beam_size`, except in the initial |
| step, for which it will just be `batch_size`. |
| |
| The function is expected to return a tuple, where the first element |
| is a tensor of shape `(group_size, vocab_size)` containing |
| the log probabilities of the tokens for the next step, and the second |
| element is the updated state. The tensor in the state should have shape |
| `(group_size, *)`, where `*` means any other number of dimensions. |
| |
| """ |
| step_signature = signature(step) |
| if len(step_signature.parameters) < 3: |
| |
| |
| old_step = cast(StepFunctionTypeNoTimestep, step) |
|
|
| def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int): |
| del time_step |
| return old_step(last_predictions, state) |
|
|
| return self._search(start_predictions, start_state, new_step) |
| else: |
| return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step)) |
|
|
| def _search( |
| self, |
| start_predictions: torch.Tensor, |
| start_state: StateType, |
| step: StepFunctionTypeWithTimestep, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| batch_size = start_predictions.size()[0] |
|
|
| |
| |
| predictions: List[torch.Tensor] = [] |
|
|
| |
| |
| |
| backpointers: List[torch.Tensor] = [] |
|
|
| constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints] |
|
|
| |
| |
| |
| |
| |
| |
| |
| start_class_log_probabilities, state = step(start_predictions, start_state, 0) |
|
|
| num_classes = start_class_log_probabilities.size()[1] |
|
|
| |
| if self.per_node_beam_size > num_classes: |
| raise ValueError( |
| f"Vocab size ({num_classes:d}) too small " |
| f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n" |
| f"Please decrease beam_size or per_node_beam_size." |
| ) |
|
|
| sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes) |
|
|
| |
| if self.constraints: |
| |
| expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1) |
| for constraint, constraint_state in zip(self.constraints, constraint_states): |
| expanded_start_class_log_probabilities = constraint.apply( |
| constraint_state, expanded_start_class_log_probabilities |
| ) |
| start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1) |
|
|
| |
| if self.min_steps >= 1: |
| start_class_log_probabilities[:, self._end_index] = torch.finfo( |
| start_class_log_probabilities.dtype |
| ).min |
|
|
| |
| |
| ( |
| start_top_log_probabilities, |
| start_predicted_classes, |
| sampler_state, |
| ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state) |
|
|
| if self.beam_size == 1 and (start_predicted_classes == self._end_index).all(): |
| warnings.warn( |
| "Empty sequences predicted. You may want to increase the beam size or ensure " |
| "your step function is working properly.", |
| RuntimeWarning, |
| ) |
| return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities |
|
|
| |
| |
| last_log_probabilities = start_top_log_probabilities |
|
|
| |
| predictions.append(start_predicted_classes) |
|
|
| |
| |
| log_probs_after_end = start_class_log_probabilities.new_full( |
| (batch_size * self.beam_size, num_classes), |
| torch.finfo(start_class_log_probabilities.dtype).min, |
| ) |
| log_probs_after_end[:, self._end_index] = 0.0 |
|
|
| |
| self._update_initial_state(state, batch_size) |
|
|
| for i, constraint in enumerate(self.constraints): |
| constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes) |
|
|
| for timestep in range(self.max_steps - 1): |
| |
| last_predictions = predictions[-1].reshape(batch_size * self.beam_size) |
|
|
| |
| |
| if (last_predictions == self._end_index).all(): |
| break |
| |
| |
| |
| class_log_probabilities, state = step(last_predictions, state, timestep + 1) |
|
|
| |
| if self.constraints: |
| |
| reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1) |
| for constraint, constraint_state in zip(self.constraints, constraint_states): |
| reshaped_class_log_probabilities = constraint.apply( |
| constraint_state, reshaped_class_log_probabilities |
| ) |
| |
| class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1) |
|
|
| |
| |
| |
| |
| if timestep + 2 <= self.min_steps: |
| class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min |
|
|
| |
| last_predictions_expanded = last_predictions.unsqueeze(-1).expand( |
| batch_size * self.beam_size, num_classes |
| ) |
|
|
| |
| |
| |
| |
| |
| cleaned_log_probabilities = torch.where( |
| last_predictions_expanded == self._end_index, |
| log_probs_after_end, |
| class_log_probabilities, |
| ) |
|
|
| |
| top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes( |
| cleaned_log_probabilities, self.per_node_beam_size, sampler_state |
| ) |
|
|
| |
| |
| |
| |
| expanded_last_log_probabilities = ( |
| last_log_probabilities.unsqueeze(2) |
| .expand(batch_size, self.beam_size, self.per_node_beam_size) |
| .reshape(batch_size * self.beam_size, self.per_node_beam_size) |
| ) |
|
|
| |
| summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities |
|
|
| |
| reshaped_summed = summed_top_log_probabilities.reshape( |
| batch_size, self.beam_size * self.per_node_beam_size |
| ) |
|
|
| |
| reshaped_predicted_classes = predicted_classes.reshape( |
| batch_size, self.beam_size * self.per_node_beam_size |
| ) |
|
|
| |
| |
| ( |
| restricted_beam_log_probs, |
| restricted_beam_indices, |
| sampler_state, |
| ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state) |
|
|
| |
| |
| restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices) |
|
|
| predictions.append(restricted_predicted_classes) |
|
|
| |
| last_log_probabilities = restricted_beam_log_probs |
|
|
| |
| |
| |
| |
| |
| backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc") |
| backpointers.append(backpointer) |
|
|
| |
| |
| self._update_state(state, backpointer) |
|
|
| for i, constraint in enumerate(self.constraints): |
| constraint_states[i] = constraint.update_state( |
| constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer |
| ) |
|
|
| |
| |
| if not self.constraints and ( |
| not torch.isfinite(last_log_probabilities).all() |
| or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any() |
| ): |
| warnings.warn( |
| "Negligible log probabilities encountered ('-inf' or equivalent). " |
| "Some final sequences may not make sense. " |
| "This can happen when the beam size is larger than the number of valid (non-zero " |
| "probability) transitions that the step function produces.", |
| RuntimeWarning, |
| ) |
|
|
| reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers) |
|
|
| |
| all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2) |
|
|
| |
| |
| final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index) |
|
|
| |
| |
| sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True) |
| sorted_all_predictions = torch.gather( |
| all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions) |
| ) |
|
|
| return sorted_all_predictions, sorted_final_scores |
|
|
| def _update_initial_state(self, state: StateType, batch_size: int): |
| """ |
| Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`. |
| """ |
| for key, state_tensor in state.items(): |
| if state_tensor is None: |
| continue |
| |
| _, *last_dims = state_tensor.size() |
| state[key] = ( |
| state_tensor.unsqueeze(1) |
| .expand(batch_size, self.beam_size, *last_dims) |
| .reshape(batch_size * self.beam_size, *last_dims) |
| ) |
|
|
| def _update_state(self, state: StateType, backpointer: torch.Tensor): |
| batch_size = backpointer.size()[0] |
|
|
| for key, state_tensor in state.items(): |
| if state_tensor is None: |
| continue |
| _, *last_dims = state_tensor.size() |
| |
| expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand( |
| batch_size, self.beam_size, *last_dims |
| ) |
| |
| state[key] = ( |
| state_tensor.reshape(batch_size, self.beam_size, *last_dims) |
| .gather(1, expanded_backpointer) |
| .reshape(batch_size * self.beam_size, *last_dims) |
| ) |
|
|