| | |
| | import warnings |
| | from pathlib import Path |
| | from typing import Optional, Union |
| |
|
| | import mmcv |
| | import numpy as np |
| | import torch |
| | from mmengine import Config |
| | from mmengine.registry import init_default_scope |
| | from mmengine.runner import load_checkpoint |
| | from mmengine.utils import mkdir_or_exist |
| |
|
| | from mmseg.models import BaseSegmentor |
| | from mmseg.registry import MODELS |
| | from mmseg.structures import SegDataSample |
| | from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette |
| | from mmseg.visualization import SegLocalVisualizer |
| | from .utils import ImageType, _preprare_data |
| |
|
| |
|
def _resolve_dataset_meta(model, ckpt_meta: dict) -> dict:
    """Work out the ``dataset_meta`` (classes and palette) for ``model``.

    Args:
        model (nn.Module): The freshly built segmentor. Only its
            ``decode_head.num_classes`` is read, and only when ``ckpt_meta``
            carries no class information.
        ckpt_meta (dict): The ``meta`` section of the loaded checkpoint, or
            an empty dict when the checkpoint has none.

    Returns:
        dict: The checkpoint's stored ``dataset_meta`` when present;
        otherwise a dict with ``'classes'`` and ``'palette'`` keys.
    """
    if 'dataset_meta' in ckpt_meta:
        # The checkpoint already stores a ready-made dataset_meta.
        return ckpt_meta['dataset_meta']
    if 'CLASSES' in ckpt_meta:
        # Checkpoints using separate upper-case CLASSES / PALETTE keys
        # (presumably an older save format).
        return {
            'classes': ckpt_meta['CLASSES'],
            'palette': ckpt_meta['PALETTE'],
        }

    # NOTE(review): simplefilter() changes the process-wide warnings filter;
    # kept as-is for backward compatibility.
    warnings.simplefilter('once')
    warnings.warn(
        'dataset_meta or class names are not saved in the '
        'checkpoint\'s meta data, classes and palette will be '
        'set according to num_classes ')
    num_classes = model.decode_head.num_classes
    # Take the first known dataset whose class count matches the head.
    dataset_name = next(
        (name for name in dataset_aliases
         if len(get_classes(name)) == num_classes),
        None)
    if dataset_name is None:
        warnings.warn(
            'No suitable dataset found, use Cityscapes by default')
        dataset_name = 'cityscapes'
    return {
        'classes': get_classes(dataset_name),
        'palette': get_palette(dataset_name)
    }


def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a segmentor from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object. A ``Config`` object passed in
            is modified in place.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
            Use 'cpu' for loading model on CPU.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed segmentor, moved to ``device``, in eval
        mode, with ``cfg`` (and, when a checkpoint is given,
        ``dataset_meta``) attached.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config``.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # Clear init_cfg on the backbone / sub-modules, presumably so that
    # building the model does not fetch pretrained init weights; the weights
    # come from ``checkpoint`` instead.
    if config.model.type == 'EncoderDecoder':
        if 'init_cfg' in config.model.backbone:
            config.model.backbone.init_cfg = None
    elif config.model.type == 'MultimodalEncoderDecoder':
        for k, v in config.model.items():
            if isinstance(v, dict) and 'init_cfg' in v:
                config.model[k].init_cfg = None
    config.model.pretrained = None
    config.model.train_cfg = None
    init_default_scope(config.get('default_scope', 'mmseg'))

    model = MODELS.build(config.model)
    if checkpoint is not None:
        ckpt = load_checkpoint(model, checkpoint, map_location='cpu')
        # ``meta`` may be missing (e.g. the file holds only a state dict) or
        # None; treat both as "no metadata" rather than raising.
        ckpt_meta = ckpt.get('meta') or {}
        # Keep dataset_meta on the model for convenience (e.g. visualization).
        model.dataset_meta = _resolve_dataset_meta(model, ckpt_meta)
    model.cfg = config
    model.to(device)
    model.eval()
    return model
| |
|
| |
|
def inference_model(model: BaseSegmentor,
                    img: ImageType) -> Union[SegDataSample, SampleList]:
    """Run a loaded segmentor on one image or a batch of images.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str/ndarray or list[str/ndarray]): Either image file(s) or
            already-loaded image(s).

    Returns:
        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
            A list of results when ``img`` is a list or tuple; otherwise the
            single segmentation result.
    """
    data, is_batch = _preprare_data(img, model)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        predictions = model.test_step(data)

    if not is_batch:
        return predictions[0]
    return predictions
| |
|
| |
|
def show_result_pyplot(model: BaseSegmentor,
                       img: Union[str, np.ndarray],
                       result: SegDataSample,
                       opacity: float = 0.5,
                       title: str = '',
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       wait_time: float = 0,
                       show: bool = True,
                       with_labels: Optional[bool] = True,
                       save_dir=None,
                       out_file=None):
    """Draw a segmentation result over its input image.

    Args:
        model (nn.Module): The segmentor that produced ``result``. If it has
            a ``module`` attribute (a wrapped model), the inner module is
            used. Its ``dataset_meta`` must provide ``'classes'`` and
            ``'palette'``.
        img (str or np.ndarray): Image path or an already-loaded image.
            Paths are read with RGB channel order; arrays are used as given.
        result (SegDataSample): The prediction SegDataSample to visualize.
        opacity (float): Blending factor handed to the visualizer as
            ``alpha``. Expected in (0, 1]. Defaults to 0.5.
        title (str): Passed to the visualizer as the sample ``name``.
            Defaults to ''.
        draw_gt (bool): Whether to draw the GT SegDataSample.
            Defaults to True.
        draw_pred (bool): Whether to draw the prediction SegDataSample.
            Defaults to True.
        wait_time (float): Display interval in seconds; 0 means "forever".
            Defaults to 0.
        show (bool): Whether to display the drawn image. Defaults to True.
        with_labels (bool, optional): Whether to add semantic labels to the
            visualization. Defaults to True.
        save_dir (str, optional): Save directory for the local
            visualization backend; created if it does not exist. If None,
            the backend stores nothing. Defaults to None.
        out_file (str, optional): Path to an output file. Defaults to None.

    Returns:
        np.ndarray: The drawn image (RGB channel order), as returned by the
        visualizer's ``get_image()``.
    """
    seg_model = model.module if hasattr(model, 'module') else model

    if isinstance(img, str):
        canvas = mmcv.imread(img, channel_order='rgb')
    else:
        canvas = img

    if save_dir is not None:
        mkdir_or_exist(save_dir)

    backends = [{'type': 'LocalVisBackend'}]
    visualizer = SegLocalVisualizer(
        vis_backends=backends, save_dir=save_dir, alpha=opacity)
    meta = seg_model.dataset_meta
    visualizer.dataset_meta = {
        'classes': meta['classes'],
        'palette': meta['palette'],
    }
    visualizer.add_datasample(
        name=title,
        image=canvas,
        data_sample=result,
        draw_gt=draw_gt,
        draw_pred=draw_pred,
        wait_time=wait_time,
        out_file=out_file,
        show=show,
        with_labels=with_labels)
    return visualizer.get_image()
| |
|