| | import logging |
| | import warnings |
| | from abc import ABCMeta, abstractmethod |
| | from collections import OrderedDict |
| |
|
| | import mmcv |
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn as nn |
| | from mmcv.runner import auto_fp16 |
| |
|
| |
|
class BaseSegmentor(nn.Module):
    """Base class for segmentors."""

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling and is ignored
    # by Python 3, so the ``@abstractmethod`` markers below are NOT enforced
    # (an incomplete subclass can still be instantiated). Left unchanged on
    # purpose: switching to ``class ...(nn.Module, metaclass=ABCMeta)`` could
    # break existing subclasses that do not implement every abstract method.
    __metaclass__ = ABCMeta

    def __init__(self):
        super(BaseSegmentor, self).__init__()
        # Presumably consulted by the ``auto_fp16`` decorator on ``forward``
        # to decide whether to cast inputs to half precision — confirm.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the segmentor has neck"""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_auxiliary_head(self):
        """bool: whether the segmentor has auxiliary head"""
        return hasattr(self,
                       'auxiliary_head') and self.auxiliary_head is not None

    @property
    def with_decode_head(self):
        """bool: whether the segmentor has decode head"""
        return hasattr(self, 'decode_head') and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Placeholder for extract features from images."""
        pass

    @abstractmethod
    def encode_decode(self, img, img_metas):
        """Placeholder for encode images with backbone and decode into a
        semantic segmentation map of the same size as input."""
        pass

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Placeholder for Forward function for training."""
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Placeholder for single image test."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Placeholder for augmentation test."""
        pass

    def init_weights(self, pretrained=None):
        """Initialize the weights in segmentor.

        The base implementation only logs the checkpoint path; it does not
        load anything itself.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is not None:
            logger = logging.getLogger()
            logger.info(f'load model from: {pretrained}')

    def forward_test(self, imgs, img_metas, **kwargs):
        """Validate test inputs and dispatch to single or augmented testing.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.

        Returns:
            Whatever :meth:`simple_test` (one augmentation) or
            :meth:`aug_test` (several augmentations) returns.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If their lengths differ.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got '
                                f'{type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) != '
                             f'num of image meta ({len(img_metas)})')
        # All images within one augmentation must share the same original,
        # resized and padded shapes.
        for img_meta in img_metas:
            ori_shapes = [_['ori_shape'] for _ in img_meta]
            assert all(shape == ori_shapes[0] for shape in ori_shapes)
            img_shapes = [_['img_shape'] for _ in img_meta]
            assert all(shape == img_shapes[0] for shape in img_shapes)
            pad_shapes = [_['pad_shape'] for _ in img_meta]
            assert all(shape == pad_shapes[0] for shape in pad_shapes)

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of dataloader. It is unpacked as
                keyword arguments into :meth:`forward` and must contain an
                ``'img'`` entry, whose length gives ``num_samples``.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self(**data_batch)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=len(data_batch['img'].data))

        return outputs

    def val_step(self, data_batch, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.

        Unlike :func:`train_step`, the raw output of :meth:`forward` is
        returned unchanged, and extra ``kwargs`` are forwarded to it.
        """
        output = self(**data_batch, **kwargs)
        return output

    @staticmethod
    def _parse_losses(losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information. Each value must be a
                tensor or a list of tensors.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which is the sum of every entry whose key contains ``'loss'``,
                log_vars contains all the variables (as Python scalars) to be
                sent to the logger, plus the total under key ``'loss'``.

        Raises:
            TypeError: If a value is neither a tensor nor a list.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        # Only entries whose key mentions 'loss' contribute to the gradient;
        # other entries (e.g. accuracies) are logged only.
        # NOTE(review): if no key contains 'loss', ``loss`` is the int 0 and
        # the ``.item()`` / ``.data`` calls below will fail.
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                # Work on a detached copy so the graph tensor is untouched,
                # then average across processes.
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or np.ndarray): The image (or its path) to be displayed;
                it is loaded with ``mmcv.imread``.
            result (list): Segmentation results; only ``result[0]`` is drawn
                and it must be a 2-D label map matching the image's height
                and width.
            palette (list[list[int]]] | np.ndarray | None): The palette of
                segmentation map. If None is given, ``self.PALETTE`` is used,
                and if that is also None a random palette is generated.
                Default: None
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image. Forced to False when
                ``out_file`` is given.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            np.ndarray: The blended uint8 image (always returned).
        """
        img = mmcv.imread(img)
        img = img.copy()
        seg = result[0]
        if palette is None:
            # ``PALETTE`` / ``CLASSES`` are not defined in this class; they
            # are expected to be attached to the instance elsewhere.
            if self.PALETTE is None:
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        # Check the rank first so a malformed (e.g. 1-D) palette fails with
        # an AssertionError instead of an IndexError on ``shape[1]``.
        assert len(palette.shape) == 2
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # Reverse the channel order; presumably the palette is RGB while
        # mmcv.imread yields BGR — confirm.
        color_seg = color_seg[..., ::-1]

        # Blend the image and the color map with equal weight.
        img = img * 0.5 + color_seg * 0.5
        img = img.astype(np.uint8)
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if show:
            mmcv.imshow(img, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
        return img
| |
|