# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn

from detectron2.config import configurable
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.logger import log_first_n
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN

from .Wordnn_embedding import WordnnEmbedding

__all__ = ["VGT"]


def torch_memory(device, tag=""):
    # Checks and prints current/peak GPU memory usage for debugging.
    print(tag, f"{torch.cuda.memory_allocated(device)/1024/1024:.2f} MB USED")
    print(tag, f"{torch.cuda.memory_reserved(device)/1024/1024:.2f} MB RESERVED")
    print(tag, f"{torch.cuda.max_memory_allocated(device)/1024/1024:.2f} MB USED MAX")
    print(tag, f"{torch.cuda.max_memory_reserved(device)/1024/1024:.2f} MB RESERVED MAX")
    print("")

@META_ARCH_REGISTRY.register()
class VGT(GeneralizedRCNN):
    @configurable
    def __init__(
        self,
        *,
        vocab_size: int = 30552,
        hidden_size: int = 768,
        embedding_dim: int = 64,
        bros_embedding_path: str = "",
        use_pretrain_weight: bool = True,
        use_UNK_text: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        # Embedding module that rasterizes OCR word tokens into a grid
        # aligned with the input image.
        self.Wordgrid_embedding = WordnnEmbedding(
            vocab_size, hidden_size, embedding_dim, bros_embedding_path, use_pretrain_weight, use_UNK_text
        )

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update(
            {
                "vocab_size": cfg.MODEL.WORDGRID.VOCAB_SIZE,
                "hidden_size": cfg.MODEL.WORDGRID.HIDDEN_SIZE,
                "embedding_dim": cfg.MODEL.WORDGRID.EMBEDDING_DIM,
                "bros_embedding_path": cfg.MODEL.WORDGRID.MODEL_PATH,
                "use_pretrain_weight": cfg.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT,
                "use_UNK_text": cfg.MODEL.WORDGRID.USE_UNK_TEXT,
            }
        )
        return ret
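
    # Sketch (assumed, mirroring the lookups above): the MODEL.WORDGRID node
    # must exist on the config before `VGT` is built. With detectron2's
    # yacs-style configs that typically means registering the keys first:
    #
    #   from detectron2.config import get_cfg, CfgNode as CN
    #   cfg = get_cfg()
    #   cfg.MODEL.WORDGRID = CN()
    #   cfg.MODEL.WORDGRID.VOCAB_SIZE = 30552
    #   cfg.MODEL.WORDGRID.HIDDEN_SIZE = 768
    #   cfg.MODEL.WORDGRID.EMBEDDING_DIM = 64
    #   cfg.MODEL.WORDGRID.MODEL_PATH = ""
    #   cfg.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT = True
    #   cfg.MODEL.WORDGRID.USE_UNK_TEXT = False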

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # Embed the OCR tokens into a word-grid feature map, then fuse it with
        # the image inside the (two-input) backbone.
        chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
        features = self.backbone(images.tensor, chargrid)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
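
    # Training-time usage sketch (assumed standard detectron2 conventions):
    # in training mode `forward` returns a dict of scalar loss tensors, which
    # a typical loop sums and backpropagates. `optimizer` is a placeholder.
    #
    #   model.train()
    #   loss_dict = model(batched_inputs)
    #   losses = sum(loss_dict.values())
    #   optimizer.zero_grad()
    #   losses.backward()
    #   optimizer.step()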

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
        features = self.backbone(images.tensor, chargrid)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results
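

# Inference usage sketch (assumed; the exact input fields depend on this
# project's DatasetMapper, which also attaches the OCR inputs consumed by
# WordnnEmbedding). `image_chw`, `h`, and `w` are placeholders.
#
#   model.eval()
#   with torch.no_grad():
#       outputs = model([{"image": image_chw, "height": h, "width": w}])
#   instances = outputs[0]["instances"]  # pred_boxes, pred_classes, scores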