Spaces:
Runtime error
Runtime error
| r""" | |
| This Beam Search implementation is adapted with minor modifications from | |
| `AllenNLP <https://github.com/allenai/allennlp/blob/master/allennlp/nn/beam_search.py>`_. | |
| Thanks to the developers of AllenNLP! | |
| """ | |
| from typing import Callable, List, Tuple | |
| import warnings | |
| import torch | |
| from torch.nn import functional as F | |
class AutoRegressiveBeamSearch(object):
    r"""
    Implements the beam search algorithm for decoding the most likely captions.
    This only works for auto-regressive models (Transformer-like) and not
    recurrent models (LSTM-like).

    Parameters
    ----------
    eos_index: int
        The index of the end token (``[EOS]``) in vocabulary.
    max_steps: int, optional (default = 50)
        The maximum number of decoding steps. The first step is always taken,
        so at least one token is decoded per sequence.
    beam_size: int, optional (default = 5)
        The width of the beam used.
    per_node_beam_size: int, optional (default = 2)
        The maximum number of candidates to consider per node, at each step in
        the search. Setting this parameter to a number smaller than `beam_size`
        may give better results, as it can introduce more diversity into the
        search. See `Beam Search Strategies for Neural Machine Translation.
        Freitag and Al-Onaizan, 2017 <https://arxiv.org/abs/1702.01806>`_.
        A falsy value (``None`` or ``0``) falls back to ``beam_size``.
    """

    def __init__(
        self,
        eos_index: int,
        max_steps: int = 50,
        beam_size: int = 5,
        per_node_beam_size: int = 2,
    ):
        self._eos_index = eos_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.per_node_beam_size = per_node_beam_size or beam_size

    def search(
        self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Given a starting state and a step function, apply beam search to find
        the most likely target captions.

        Parameters
        ----------
        start_predictions : torch.Tensor
            Tensor containing the initial predictions, shape ``(batch_size, )``.
            Usually the initial predictions are just the index of the start
            token (``[SOS]``) in the vocabulary.
        step : Callable[..., torch.Tensor]
            A function that is responsible for computing the next most likely
            tokens, given the past predictions. Predictions from all previous
            timesteps are required, not just the last timestep, because our
            model is auto-regressive instead of recurrent. It is first called
            with ``start_predictions`` and afterwards with a tensor of shape
            ``(batch_size * beam_size, num_decoded_steps)`` holding the tokens
            decoded so far for every hypothesis on the beam (the start token
            is not included). The function is expected to return a tensor of
            shape ``(group_size, target_vocab_size)`` containing the logits of
            the tokens for the next step.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Tuple of ``(predictions, logprobs)`` for the highest scoring beam
            only. ``predictions`` has shape ``(batch_size, num_decoded_steps)``
            with ``num_decoded_steps <= max(max_steps, 1)`` (decoding stops
            early once every beam has produced ``[EOS]``), and ``logprobs``
            has shape ``(batch_size, )``.

        Raises
        ------
        ValueError
            If ``per_node_beam_size`` is larger than the vocabulary size
            reported by ``step``.
        """
        batch_size = start_predictions.size()[0]
        group_size = batch_size * self.beam_size

        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input per batch element
        # to the top `beam_size` decoder outputs. Within the main loop we go
        # from `beam_size` hypotheses to `beam_size * per_node_beam_size`
        # candidates, from which the top `beam_size` are kept.
        # shape: (batch_size, num_classes)
        start_class_logits = step(start_predictions)

        # Convert logits to logprobs.
        # shape: (batch_size, num_classes)
        start_class_logprobs = F.log_softmax(start_class_logits, dim=1)
        num_classes = start_class_logprobs.size()[1]

        # Make sure `per_node_beam_size` is not larger than `num_classes`.
        if self.per_node_beam_size > num_classes:
            raise ValueError(
                f"Target vocab size ({num_classes:d}) too small "
                f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
                f"Please decrease beam_size or per_node_beam_size."
            )

        # shape: (batch_size, beam_size), (batch_size, beam_size)
        start_top_logprobs, start_predicted_classes = start_class_logprobs.topk(
            self.beam_size
        )

        if (
            self.beam_size == 1
            and (start_predicted_classes == self._eos_index).all()
        ):
            warnings.warn(
                "Empty captions predicted. You may want to increase beam "
                "size or ensure your step function is working properly.",
                RuntimeWarning,
            )
            # With `beam_size == 1` the class tensor already is
            # `(batch_size, 1)`, i.e. a one-step caption per batch element.
            # Return it in the same top-beam layout as the regular exit path.
            return start_predicted_classes, start_top_logprobs[:, 0]

        # The accumulated log probs of every hypothesis on the beam.
        # shape: (batch_size, beam_size)
        last_logprobs = start_top_logprobs

        # Tokens decoded so far by every hypothesis currently on the beam. It
        # is re-ordered with the backpointers after every step, so row `n`
        # always holds the full history of the hypothesis in beam slot `n`.
        # shape: (batch_size, beam_size, num_decoded_steps)
        predictions = start_predicted_classes.unsqueeze(-1)

        # Log probability tensor that mandates that the end token is selected.
        # shape: (batch_size * beam_size, num_classes)
        logprobs_after_end = start_class_logprobs.new_full(
            (group_size, num_classes), float("-inf")
        )
        logprobs_after_end[:, self._eos_index] = 0.0

        for _ in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[:, :, -1].reshape(group_size)

            # If every predicted token from the last step is `self._eos_index`,
            # then we can stop early.
            if (last_predictions == self._eos_index).all():
                break

            # Take a step. This gets the predicted logits of the next classes
            # given the (correctly ordered) history of every hypothesis.
            # shape: (batch_size * beam_size, num_decoded_steps)
            predictions_so_far = predictions.reshape(group_size, -1)

            # shape: (batch_size * beam_size, num_classes)
            class_logits = step(predictions_so_far)

            # Convert logits to logprobs.
            # shape: (batch_size * beam_size, num_classes)
            class_logprobs = F.log_softmax(class_logits, dim=1)

            # Set logprobs of the last predicted tokens to a large negative
            # value to avoid immediate repetition in the caption.
            class_logprobs = class_logprobs.scatter(
                1, last_predictions.unsqueeze(-1), -10000
            )

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                group_size, num_classes
            )

            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            # shape: (batch_size * beam_size, num_classes)
            cleaned_logprobs = torch.where(
                last_predictions_expanded == self._eos_index,
                logprobs_after_end,
                class_logprobs,
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_logprobs, predicted_classes = cleaned_logprobs.topk(
                self.per_node_beam_size
            )

            # Expand the accumulated log probs so that they can be added to
            # the log probs of the candidates of this timestep.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_logprobs = (
                last_logprobs.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(group_size, self.per_node_beam_size)
            )

            # shape: (batch_size * beam_size, per_node_beam_size)
            summed_top_logprobs = top_logprobs + expanded_last_logprobs

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_logprobs.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # Keep only the top `beam_size` beam indices.
            # shape: (batch_size, beam_size), (batch_size, beam_size)
            restricted_beam_logprobs, restricted_beam_indices = reshaped_summed.topk(
                self.beam_size
            )

            # Use the beam indices to extract the corresponding classes.
            # shape: (batch_size, beam_size)
            restricted_predicted_classes = reshaped_predicted_classes.gather(
                1, restricted_beam_indices
            )

            # shape: (batch_size, beam_size)
            last_logprobs = restricted_beam_logprobs

            # The beam indices come from a `beam_size * per_node_beam_size`
            # dimension where the indices with a common ancestor are grouped
            # together. Hence dividing by `per_node_beam_size` gives the
            # ancestor. (Note that this is integer division as the tensor is a
            # LongTensor.)
            # shape: (batch_size, beam_size)
            backpointer = restricted_beam_indices // self.per_node_beam_size

            # Re-order the stored histories so that every surviving candidate
            # is preceded by the tokens of its own ancestor, then append the
            # newly chosen tokens.
            # shape: (batch_size, beam_size, num_decoded_steps)
            expanded_backpointer = backpointer.unsqueeze(-1).expand(
                batch_size, self.beam_size, predictions.size(-1)
            )
            predictions = torch.cat(
                [
                    predictions.gather(1, expanded_backpointer),
                    restricted_predicted_classes.unsqueeze(-1),
                ],
                dim=-1,
            )

        if not torch.isfinite(last_logprobs).all():
            warnings.warn(
                "Infinite log probs encountered. Some final captions may not "
                "make sense. This can happen when the beam size is larger than"
                " the number of valid (non-zero probability) transitions that "
                "the step function produces.",
                RuntimeWarning,
            )

        # Select the top-beam and its logprobs. `topk` returns its results in
        # descending order, so slot 0 is the highest scoring hypothesis.
        # shape: (batch_size, num_decoded_steps), (batch_size, )
        return predictions[:, 0, :], last_logprobs[:, 0]