import copy
import functools
from typing import Any, Dict

import torch
from torch import nn

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class CaptioningModel(nn.Module):
    r"""
    A model to perform image captioning (in both forward and backward
    directions independently during training; only in forward direction during
    inference). It is composed of a
    :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    During training, it maximizes the likelihood of ground truth caption
    conditioned on image features. During inference, it predicts a caption for
    an input image through beam search decoding.

    Parameters
    ----------
    visual: virtex.modules.visual_backbones.VisualBackbone
        A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
        computes visual features from an input image.
    textual: virtex.modules.textual_heads.TextualHead
        A :class:`~virtex.modules.textual_heads.TextualHead` which
        makes final predictions conditioned on visual features.
    sos_index: int, optional (default = 1)
        The index of the start token (``[SOS]``) in vocabulary.
    eos_index: int, optional (default = 2)
        The index of the end token (``[EOS]``) in vocabulary.
    caption_backward: bool, optional (default = False)
        Whether to *also* perform captioning in backward direction. Default is
        ``False`` -- only forward captioning is performed. When ``True``, a
        clone of textual head is created, which does not share weights with
        "forward" model except input and output embeddings.
    label_smoothing: float, optional (default = 0.0)
        Label smoothing value for the cross entropy loss; ``0.0`` disables
        smoothing.
    decoder: Any, optional (default = None)
        An instance of :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
        or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`
        for decoding captions during inference (unused during training).
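
    Examples
    --------
    A minimal training-step sketch; ``visual``, ``textual``, and the batch
    tensors are assumed to be constructed elsewhere (e.g. from a config):

    >>> model = CaptioningModel(visual, textual, sos_index=1, eos_index=2)
    >>> batch = {
    ...     "image": images,                     # (batch_size, 3, H, W)
    ...     "caption_tokens": caption_tokens,    # (batch_size, max_length)
    ...     "caption_lengths": caption_lengths,  # (batch_size, )
    ... }
    >>> model(batch)["loss"].backward()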
| """ | |

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        caption_backward: bool = False,
        sos_index: int = 1,
        eos_index: int = 2,
        label_smoothing: float = 0.0,
        decoder: Any = None,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx

        self.caption_backward = caption_backward

        # Clone the textual module for backward direction if doing captioning
        # in both directions (separately).
        if self.caption_backward:
            self.backward_textual = copy.deepcopy(self.textual)

            # Share weights for visual projection, and input/output embeddings.
            self.backward_textual.visual_projection = self.textual.visual_projection
            self.backward_textual.embedding = self.textual.embedding
            self.backward_textual.output = self.textual.output

        # These boundary indices are needed for beam search.
        self.sos_index = sos_index
        self.eos_index = eos_index
        self.decoder = decoder
        self.loss = CrossEntropyLossWithLabelSmoothing(
            label_smoothing, ignore_index=self.padding_idx
        )

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        r"""
        Given a batch of images and captions, compute log likelihood loss per
        caption token during training. During inference (with images), predict
        a caption through either beam search decoding or nucleus sampling.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            Training or inference batch. During training, a batch would at least
            contain keys ``{"image", "caption_tokens", "caption_lengths"}`` and
            also ``"noitpac_tokens"`` for bicaptioning.
            During inference, a batch would contain key ``{"image"}`` and
            optionally ``"decode_prompt"`` as a partial sequence for decoding.

        Returns
        -------
        Dict[str, Any]
            A dict with the following structure, containing loss for optimization,
            loss components to log directly to tensorboard, and optionally
            predictions.

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "captioning_forward": torch.Tensor,
                        "captioning_backward": torch.Tensor,  # optional
                    },
                    "predictions": torch.Tensor
                }
        """

        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        if "caption_tokens" in batch:
            caption_tokens = batch["caption_tokens"]
            caption_lengths = batch["caption_lengths"]

            # shape: (batch_size, max_caption_length, vocab_size)
            output_logits = self.textual(
                visual_features, caption_tokens, caption_lengths
            )
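            # Shift by one timestep for next-token prediction: logits at
            # position ``t`` are scored against the ground truth token at
            # position ``t + 1``; padding positions are ignored by the loss.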
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                caption_tokens[:, 1:].contiguous().view(-1),
            )
            output_dict: Dict[str, Any] = {
                "loss": loss,
                # Single scalar per batch for logging in training script.
                "loss_components": {"captioning_forward": loss.clone().detach()},
            }

            # Do captioning in backward direction if specified.
            if self.caption_backward:
                backward_caption_tokens = batch["noitpac_tokens"]

                backward_output_logits = self.backward_textual(
                    visual_features, backward_caption_tokens, caption_lengths
                )
                backward_loss = self.loss(
                    backward_output_logits[:, :-1]
                    .contiguous()
                    .view(-1, self.textual.vocab_size),
                    backward_caption_tokens[:, 1:].contiguous().view(-1),
                )
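                # Total loss is the sum of forward and backward captioning
                # losses; each component is also logged separately below.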
| output_dict["loss"] += backward_loss | |
| # Single scalar per batch for logging in training script. | |
| output_dict["loss_components"].update( | |
| captioning_backward=backward_loss.clone().detach() | |
| ) | |
| if not self.training: | |
| # During validation (while pretraining), get best prediction | |
| # at every timestep. | |
| output_dict["predictions"] = torch.argmax(output_logits, dim=-1) | |
| else: | |
| if self.decoder is None: | |
| raise ValueError("Decoder for predicting captions is missing!") | |
| # During inference, decode captions from forward transformer model. | |
| # Check if the batch contains decoding prompt. | |
| if "decode_prompt" in batch: | |
| # shape: (batch_size, prompt_length) | |
| start_predictions = torch.unsqueeze(batch["decode_prompt"], 0) | |
| start_predictions = start_predictions.repeat(batch_size, 1) | |
| else: | |
| # shape: (batch_size, ) | |
| start_predictions = torch.full( | |
| (batch_size,), self.sos_index, device=visual_features.device | |
| ).long() | |
| # Add image features as a default argument to match callable | |
| # signature accepted by beam search class (partial captions only). | |
| decoding_step = functools.partial(self.decoding_step, visual_features) | |
| predicted_caption, _ = self.decoder.search( | |
| start_predictions, decoding_step | |
| ) | |
| output_dict = {"predictions": predicted_caption} | |
| return output_dict | |

    def decoding_step(
        self, visual_features: torch.Tensor, partial_captions: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Given visual features and a batch of (assumed) partial captions, predict
        the logits over output vocabulary tokens for next timestep. This method
        is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
        and :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`.

        .. note::

            For nucleus sampling, ``beam_size`` will always be 1 (not relevant).

        Parameters
        ----------
        visual_features: torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``
            containing features from the visual backbone; projection to
            ``textual_feature_size`` happens inside the textual head.
        partial_captions: torch.Tensor
            A tensor of shape ``(batch_size * beam_size, timesteps)``
            containing tokens predicted so far -- one for each beam. We need all
            prior predictions because our model is auto-regressive.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- logits
            over output vocabulary tokens for next timestep.
        """
        # Expand and repeat image features while doing beam search.
        batch_size, channels, height, width = visual_features.size()
        beam_size = int(partial_captions.size(0) / batch_size)
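        # ``partial_captions`` arrives as (batch_size * beam_size, ...) with all
        # beams of one image contiguous; the tiling below replicates image
        # features in that same order.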
        if beam_size > 1:
            # shape: (batch_size * beam_size, channels, height, width)
            visual_features = visual_features.unsqueeze(1).repeat(1, beam_size, 1, 1, 1)
            visual_features = visual_features.view(
                batch_size * beam_size, channels, height, width
            )

        # Provide caption lengths as current length (irrespective of predicted
        # EOS/padding tokens). shape: (batch_size, )
        caption_lengths = torch.ones_like(partial_captions)
        if len(caption_lengths.size()) == 2:
            caption_lengths = caption_lengths.sum(1)
        else:
            # Add a timestep. shape: (batch_size, 1)
            partial_captions = partial_captions.unsqueeze(1)

        # shape: (batch_size * beam_size, partial_caption_length, vocab_size)
        logits = self.textual(visual_features, partial_captions, caption_lengths)

        # Return logits from the last timestep.
        return logits[:, -1, :]
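
    # A minimal sketch of how a decoder consumes ``decoding_step`` (shapes are
    # illustrative assumptions; ``forward`` above shows the actual call through
    # ``functools.partial``):
    #
    #     step = functools.partial(model.decoding_step, visual_features)
    #     logits = step(partial_captions)  # (batch_size * beam_size, vocab_size)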

    def log_predictions(
        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
    ) -> str:

        self.eval()
        with torch.no_grad():
            predictions = self.forward(batch)["predictions"]
        self.train()

        predictions_str = ""
        for tokens, preds in zip(batch["caption_tokens"], predictions):
            # Decode token IDs back to strings for readable logging.
            predictions_str += f"""
                Caption tokens : {tokenizer.decode(tokens.tolist())}
                Predictions (f): {tokenizer.decode(preds.tolist())}
                """
        return predictions_str


class ForwardCaptioningModel(CaptioningModel):
    r"""
    Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`
    for better readability: this passes ``caption_backward=False`` to super class.
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        sos_index: int = 1,
        eos_index: int = 2,
        label_smoothing: float = 0.0,
        decoder: Any = None,
    ):
        super().__init__(
            visual,
            textual,
            sos_index=sos_index,
            eos_index=eos_index,
            caption_backward=False,
            label_smoothing=label_smoothing,
            decoder=decoder,
        )


class BidirectionalCaptioningModel(CaptioningModel):
    r"""
    Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`
    for better readability: this passes ``caption_backward=True`` to super class.
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        sos_index: int = 1,
        eos_index: int = 2,
        label_smoothing: float = 0.0,
        decoder: Any = None,
    ):
        super().__init__(
            visual,
            textual,
            sos_index=sos_index,
            eos_index=eos_index,
            caption_backward=True,
            label_smoothing=label_smoothing,
            decoder=decoder,
        )


# Convenient handle for our main model.
VirTexModel = BidirectionalCaptioningModel
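
# A minimal inference sketch, assuming a beam search decoder from
# ``virtex.utils.beam_search`` (constructor arguments shown here are
# assumptions; check that module for the exact signature):
#
#     from virtex.utils.beam_search import AutoRegressiveBeamSearch
#
#     decoder = AutoRegressiveBeamSearch(eos_index=2, beam_size=5, max_steps=30)
#     model = VirTexModel(visual, textual, decoder=decoder).eval()
#     with torch.no_grad():
#         predictions = model({"image": images})["predictions"]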