| | import numpy as np |
| | import torch |
| | from mmcv import ConfigDict |
| | from torch import nn |
| |
|
| | from mmseg.models import BACKBONES, HEADS, build_segmentor |
| | from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead |
| | from mmseg.models.decode_heads.decode_head import BaseDecodeHead |
| |
|
| |
|
def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions, unpacked as (N, C, H, W)

        num_classes (int):
            number of semantic classes; generated labels cover the full
            range ``0 .. num_classes - 1``

    Returns:
        dict: ``imgs`` (float tensor shaped like ``input_shape``),
            ``img_metas`` (list of N per-image meta dicts) and
            ``gt_semantic_seg`` (long tensor of shape (N, 1, H, W)).
    """
    (N, C, H, W) = input_shape

    # Fixed seed: every call with the same arguments yields the same batch.
    rng = np.random.RandomState(0)

    imgs = rng.rand(*input_shape)
    # ``randint`` excludes ``high``, so ``num_classes`` itself is passed to
    # make the last class id reachable (and to keep ``num_classes=1`` valid).
    # int64 avoids the silent wrap-around a uint8 cast causes for ids >= 256.
    segs = rng.randint(
        low=0, high=num_classes, size=(N, 1, H, W)).astype(np.int64)

    # One identical meta dict per image in the batch.
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]

    mm_inputs = {
        'imgs': torch.FloatTensor(imgs),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs
| |
|
| |
|
@BACKBONES.register_module()
class ExampleBackbone(nn.Module):
    """Minimal backbone registered for the tests in this file.

    It holds one 3x3 convolution with 3 input and 3 output channels.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def init_weights(self, pretrained=None):
        """Accept ``pretrained`` and do nothing."""
        return None

    def forward(self, x):
        """Return the convolved input wrapped in a one-element list."""
        feature = self.conv(x)
        return [feature]
| |
|
| |
|
@HEADS.register_module()
class ExampleDecodeHead(BaseDecodeHead):
    """Minimal decode head registered for the tests in this file.

    The base class is initialised with positional arguments ``3, 3`` and
    ``num_classes=19``.
    """

    def __init__(self):
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs):
        """Apply the inherited ``cls_seg`` to the first element of inputs."""
        first_feature = inputs[0]
        return self.cls_seg(first_feature)
| |
|
| |
|
@HEADS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):
    """Minimal cascade decode head registered for the tests in this file.

    The base class is initialised with positional arguments ``3, 3`` and
    ``num_classes=19``.
    """

    def __init__(self):
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs, prev_out):
        """Apply the inherited ``cls_seg`` to the first element of inputs.

        ``prev_out`` is accepted but not used.
        """
        first_feature = inputs[0]
        return self.cls_seg(first_feature)
| |
|
| |
|
def _segmentor_forward_train_test(segmentor):
    """Run one training-style and two inference-style forward passes.

    Args:
        segmentor: model exposing ``decode_head`` (a single head or an
            ``nn.ModuleList`` of heads, each with ``num_classes``) and a
            ``forward`` method accepting ``return_loss``.
    """
    # With a list of heads, the last one supplies the number of classes.
    heads = segmentor.decode_head
    final_head = heads[-1] if isinstance(heads, nn.ModuleList) else heads

    batch = _demo_mm_inputs(num_classes=final_head.num_classes)
    imgs = batch.pop('imgs')
    img_metas = batch.pop('img_metas')
    gt_semantic_seg = batch['gt_semantic_seg']

    # Move the model and the tensors to the GPU when one is available.
    if torch.cuda.is_available():
        segmentor = segmentor.cuda()
        imgs = imgs.cuda()
        gt_semantic_seg = gt_semantic_seg.cuda()

    # Forward with ``return_loss=True``; the result must be a dict.
    losses = segmentor.forward(
        imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
    assert isinstance(losses, dict)

    # Forward with ``return_loss=False``: one list entry per image, each
    # image given a leading batch axis and each meta wrapped in a list.
    with torch.no_grad():
        segmentor.eval()
        single_imgs = [img.unsqueeze(0) for img in imgs]
        single_metas = [[meta] for meta in img_metas]
        segmentor.forward(single_imgs, single_metas, return_loss=False)

    # Same call again with both lists repeated twice.
    with torch.no_grad():
        segmentor.eval()
        doubled_imgs = [img.unsqueeze(0) for img in imgs] * 2
        doubled_metas = [[meta] for meta in img_metas] * 2
        segmentor.forward(doubled_imgs, doubled_metas, return_loss=False)
| |
|
| |
|
def test_encoder_decoder():
    """Build several ``EncoderDecoder`` configs and run the forward checks.

    Covered variants: no auxiliary head with ``whole`` and ``slide`` test
    modes, one auxiliary head, and a list of two auxiliary heads.
    """
    # No auxiliary head, test mode 'whole'.
    plain_cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    _segmentor_forward_train_test(build_segmentor(plain_cfg))

    # Same config, switched to test mode 'slide'.
    plain_cfg.test_cfg = ConfigDict(
        mode='slide', crop_size=(3, 3), stride=(2, 2))
    _segmentor_forward_train_test(build_segmentor(plain_cfg))

    # A single auxiliary head first, then a list of two.
    one_aux_head = dict(type='ExampleDecodeHead')
    two_aux_heads = [
        dict(type='ExampleDecodeHead'),
        dict(type='ExampleDecodeHead')
    ]
    for aux_head in (one_aux_head, two_aux_heads):
        aux_cfg = ConfigDict(
            type='EncoderDecoder',
            backbone=dict(type='ExampleBackbone'),
            decode_head=dict(type='ExampleDecodeHead'),
            auxiliary_head=aux_head)
        aux_cfg.test_cfg = ConfigDict(mode='whole')
        _segmentor_forward_train_test(build_segmentor(aux_cfg))
| |
|
| |
|
def test_cascade_encoder_decoder():
    """Build several ``CascadeEncoderDecoder`` configs and run the checks.

    Covered variants: no auxiliary head with ``whole`` and ``slide`` test
    modes, one auxiliary head, and a list of two auxiliary heads.
    """

    def make_cfg(**extra):
        # Two-stage cascade shared by every variant; ``extra`` adds
        # optional keys such as ``auxiliary_head``.
        return ConfigDict(
            type='CascadeEncoderDecoder',
            num_stages=2,
            backbone=dict(type='ExampleBackbone'),
            decode_head=[
                dict(type='ExampleDecodeHead'),
                dict(type='ExampleCascadeDecodeHead')
            ],
            **extra)

    # No auxiliary head: one config object, rebuilt per test mode.
    cascade_cfg = make_cfg()
    test_modes = (
        ConfigDict(mode='whole'),
        ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)),
    )
    for test_cfg in test_modes:
        cascade_cfg.test_cfg = test_cfg
        _segmentor_forward_train_test(build_segmentor(cascade_cfg))

    # A single auxiliary head first, then a list of two.
    one_aux_head = dict(type='ExampleDecodeHead')
    two_aux_heads = [
        dict(type='ExampleDecodeHead'),
        dict(type='ExampleDecodeHead')
    ]
    for aux_head in (one_aux_head, two_aux_heads):
        aux_cfg = make_cfg(auxiliary_head=aux_head)
        aux_cfg.test_cfg = ConfigDict(mode='whole')
        _segmentor_forward_train_test(build_segmentor(aux_cfg))
| |
|