Spaces:
Runtime error
Runtime error
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from torch.nn import functional as F

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone
class ClassificationModel(nn.Module):
    r"""
    A model to perform classification (generally, with multiple targets). It is
    composed of a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    .. note::

        As with currently available textual heads, only one textual head is
        supported here: :class:`~virtex.modules.textual_heads.LinearTextualHead`.

    During training, it minimizes the KL-divergence loss with a K-hot vector,
    with values ``1/K``, where K is the number of unique labels to classify.

    Parameters
    ----------
    visual: virtex.modules.visual_backbones.VisualBackbone
        A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
        computes visual features from an input image.
    textual: virtex.modules.textual_heads.TextualHead
        A :class:`~virtex.modules.textual_heads.TextualHead` which
        makes final predictions conditioned on visual features.
    ignore_indices: List[int]
        Ignore a set of token indices while computing KL-divergence loss. These
        are usually the special tokens such as ``[SOS]``, ``[EOS]`` etc.
    """

    def __init__(
        self,
        visual: "VisualBackbone",
        textual: "TextualHead",
        ignore_indices: List[int],
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.ignore_indices = ignore_indices

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        r"""
        Given a batch of images and set of labels, perform classification with
        multiple targets by minimizing a KL-divergence loss.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            A batch of images and labels. Possible set of keys:
            ``{"image_id", "image", "labels"}``

        Returns
        -------
        Dict[str, Any]
            A dict with the following structure, containing loss for optimization,
            loss components to log directly to tensorboard, and optionally
            predictions (only when the module is not in training mode).

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "classification": torch.Tensor,
                    },
                    "predictions": torch.Tensor
                }
        """
        # shape: (batch_size, visual_feature_size, ...)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # Get logits and further log-probabilities.
        # shape: (batch_size, vocab_size)
        logits = self.textual(visual_features)
        logprobs = F.log_softmax(logits, dim=1)

        # Build the lookup of ignored indices once, outside the per-instance loop.
        ignored = set(self.ignore_indices)

        # Average log-probs per unique token in associated caption to compute
        # loss. This is simply cross-entropy with target-vector as a K-hot
        # vector. Do in a for-loop, there isn't a straightforward vectorized way.
        loss = torch.tensor(0.0, device=logprobs.device)
        for index in range(batch_size):
            # Unique labels of this instance as plain Python ints, so that the
            # membership test and the indexing below work on integers.
            unique_labels = batch["labels"][index].unique().tolist()

            # Ignore indices of special tokens such as [SOS], [EOS] etc. and
            # any other token specified.
            targets = [label for label in unique_labels if label not in ignored]

            # An instance whose labels are all ignored has no target: taking the
            # mean over an empty selection would be NaN and poison the whole
            # batch loss, so such an instance contributes nothing.
            if not targets:
                continue

            # Accumulate negative mean log-probability of the target tokens.
            loss = loss - logprobs[index, targets].mean()

        # Average loss across instances.
        output_dict: Dict[str, Any] = {"loss": loss / batch_size}

        # Single scalar per batch for logging to tensorboard in training script.
        output_dict["loss_components"] = {
            "classification": loss.clone().detach() / batch_size
        }

        # Return top-10 tokens according to log-probabilities during validation.
        # Useful for logging. ``k`` is clamped because ``topk`` raises when asked
        # for more entries than the vocabulary dimension holds.
        if not self.training:
            top_k = min(10, logprobs.size(1))
            _, top_tokens = logprobs.topk(k=top_k, dim=1)
            output_dict["predictions"] = top_tokens

        return output_dict
class TokenClassificationModel(ClassificationModel):
    r"""
    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`
    for better readability (this only modifies the tensorboard logging logic).

    Ground truth targets here are a set of unique caption tokens (ignoring the
    special tokens like ``[SOS]``, ``[EOS]`` etc.).
    """

    def log_predictions(
        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
    ) -> str:
        r"""
        Run a forward pass in evaluation mode (without gradients) and format
        the caption tokens next to the predicted tokens of every instance.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            A batch accepted by :meth:`forward`; it must additionally contain
            the key ``"caption_tokens"``.
        tokenizer: SentencePieceBPETokenizer
            Tokenizer used to decode caption tokens and to map each predicted
            id to its token.

        Returns
        -------
        str
            One text entry per instance, concatenated.
        """
        # Remember the current mode and restore it afterwards; unconditionally
        # calling ``self.train()`` would switch a model that was in evaluation
        # mode into training mode as a side effect of logging.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                predictions = self.forward(batch)["predictions"]
        finally:
            self.train(was_training)

        entries = []
        for tokens, preds in zip(batch["caption_tokens"], predictions):
            # Predictions here are individual tokens, and do not have any order
            # like captions, so decode them separately so we don't strip off
            # metaspace character and special tokens if any.
            pred_tokens = [tokenizer.id_to_token(p) for p in preds.tolist()]
            caption = tokenizer.decode(tokens.tolist())
            joined = " ".join(pred_tokens)
            entries.append(
                f"\nCaption tokens : {caption}\nPredictions (f): {joined}\n"
            )

        return "".join(entries)
class MultiLabelClassificationModel(ClassificationModel):
    r"""
    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`
    for better readability (this only modifies the tensorboard logging logic).

    Ground truth targets here are a set of unique instances in images (ignoring
    the special background token, category id = 0 in COCO).
    """

    def log_predictions(
        self,
        batch: Dict[str, torch.Tensor],
        tokenizer: Optional[SentencePieceBPETokenizer] = None,
    ) -> str:
        r"""
        Run a forward pass in evaluation mode (without gradients) and format
        the ground-truth instance IDs next to the predicted IDs of every
        instance.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            A batch accepted by :meth:`forward`; it must additionally contain
            the key ``"caption_tokens"``.
        tokenizer: SentencePieceBPETokenizer, optional
            Accepted only to keep a consistent API; it is not used here.

        Returns
        -------
        str
            One text entry per instance, concatenated.
        """
        # Remember the current mode and restore it afterwards; unconditionally
        # calling ``self.train()`` would switch a model that was in evaluation
        # mode into training mode as a side effect of logging.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                predictions = self.forward(batch)["predictions"]
        finally:
            self.train(was_training)

        # NOTE(review): ``forward`` reads ``batch["labels"]`` while ground truth
        # is read from ``batch["caption_tokens"]`` here -- confirm that the
        # dataset provides both keys.
        entries = []
        for tokens, preds in zip(batch["caption_tokens"], predictions):
            # Predictions here are COCO category IDs, let them be as is.
            # Sorted ground truth, remove background tokens.
            tokens = sorted(t for t in tokens.tolist() if t != 0)
            # Keep only as many top predictions as there are ground-truth IDs.
            preds = sorted(preds.tolist()[: len(tokens)])
            entries.append(
                f"\nCOCO Instance IDs (GT) : {tokens}"
                f"\nCOCO Instance IDs (Pred) : {preds}\n"
            )

        return "".join(entries)