Spaces:
Runtime error
Runtime error
| import os.path as osp | |
| import warnings | |
| from typing import Optional, Sequence, Any | |
| import mmcv | |
| from lightning import Callback | |
| from mmengine.fileio import get | |
| from mmengine.hooks import Hook | |
| from mmengine.runner import Runner | |
| from mmengine.utils import mkdir_or_exist | |
| from mmengine.visualization import Visualizer | |
| from mmpl.registry import HOOKS | |
| from mmdet.structures import DetDataSample | |
class DetVisualizationHook(Callback):
    """Detection Visualization Hook.

    Used to visualize validation, testing and prediction results.

    In the testing / prediction phase:

    1. If ``show`` is True, it means that only the prediction results are
       visualized without storing data, so ``vis_backends`` needs to
       be excluded.
    2. If ``test_out_dir`` is specified, it means that the prediction results
       need to be saved to ``test_out_dir``. In order to avoid vis_backends
       also storing data, so ``vis_backends`` needs to be excluded.
    3. ``vis_backends`` takes effect if the user does not specify ``show``
       and ``test_out_dir``. You can set ``vis_backends`` to WandbVisBackend
       or TensorboardVisBackend to store the prediction result in Wandb or
       Tensorboard.

    Args:
        draw (bool): whether to draw prediction results. If it is False,
            it means that no drawing will be done. Defaults to False.
        interval (int): The interval of visualization. Defaults to 50.
        score_thr (float): The threshold to visualize the bboxes
            and masks. Defaults to 0.3.
        show (bool): Whether to display the drawn image. Default to False.
        wait_time (float): The interval of show (s). Defaults to 0.
        test_out_dir (str, optional): directory where painted images
            will be saved in testing process. It is joined onto the run's
            root directory the first time it is needed.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 score_thr: float = 0.3,
                 show: bool = False,
                 wait_time: float = 0.,
                 test_out_dir: Optional[str] = None,
                 backend_args: Optional[dict] = None):
        self._visualizer: Visualizer = Visualizer.get_current_instance()
        self.interval = interval
        self.score_thr = score_thr
        self.show = show
        if self.show:
            # No need to think about vis backends.
            self._visualizer._vis_backends = {}
            warnings.warn('The show is True, it means that only '
                          'the prediction results are visualized '
                          'without storing data, so vis_backends '
                          'needs to be excluded.')

        self.wait_time = wait_time
        self.backend_args = backend_args
        self.draw = draw
        self.test_out_dir = test_out_dir
        # Remembers the already-resolved output directory so that it is not
        # joined onto the root directory a second time.
        self._resolved_out_dir: Optional[str] = None
        self._test_index = 0

    def _resolve_test_out_dir(self, *root_parts: str) -> None:
        """Anchor ``test_out_dir`` under ``root_parts`` once and create it.

        Repeated calls are idempotent: the join is only performed when
        ``test_out_dir`` differs from the value produced by the previous
        resolution (i.e. on first use, or after it was reassigned).

        Args:
            root_parts (str): Path components to prepend to
                ``test_out_dir``.
        """
        if self.test_out_dir is None:
            return
        if self.test_out_dir != self._resolved_out_dir:
            self.test_out_dir = osp.join(*root_parts, self.test_out_dir)
            self._resolved_out_dir = self.test_out_dir
        mkdir_or_exist(self.test_out_dir)

    def _load_image(self, img_path: str):
        """Fetch ``img_path`` through the file backend and decode it as RGB."""
        img_bytes = get(img_path, backend_args=self.backend_args)
        return mmcv.imfrombytes(img_bytes, channel_order='rgb')

    def _draw_test_sample(self, data_sample: DetDataSample) -> None:
        """Draw one test/predict sample and advance the test step counter.

        When ``test_out_dir`` is set, the drawn image is written there under
        the base name of the source image.

        Args:
            data_sample (:obj:`DetDataSample`): Sample holding ``img_path``
                and the predictions to draw.
        """
        self._test_index += 1

        img_path = data_sample.img_path
        img_name = osp.basename(img_path)
        img = self._load_image(img_path)

        out_file = None
        if self.test_out_dir is not None:
            out_file = osp.join(self.test_out_dir, img_name)

        self._visualizer.add_datasample(
            img_name if self.show else 'test_img',
            img,
            data_sample=data_sample,
            show=self.show,
            wait_time=self.wait_time,
            pred_score_thr=self.score_thr,
            out_file=out_file,
            step=self._test_index)

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[DetDataSample]) -> None:
        """Run after every ``self.interval`` validation iterations.

        Only the first sample of the batch is visualized.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if not self.draw:
            return

        # There is no guarantee that the same batch of images
        # is visualized for each evaluation.
        total_curr_iter = runner.iter + batch_idx
        # Check the interval first so that the image is only fetched and
        # decoded when it is actually going to be drawn.
        if total_curr_iter % self.interval != 0:
            return
        if not outputs:
            return

        # Visualize only the first data
        first_sample = outputs[0]
        img_path = first_sample.img_path
        img = self._load_image(img_path)

        self._visualizer.add_datasample(
            osp.basename(img_path) if self.show else 'val_img',
            img,
            data_sample=first_sample,
            show=self.show,
            wait_time=self.wait_time,
            pred_score_thr=self.score_thr,
            step=total_curr_iter)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[DetDataSample]) -> None:
        """Run after every testing iteration.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if not self.draw:
            return

        # Resolved once; later iterations reuse the same directory instead
        # of nesting the path deeper on every call.
        self._resolve_test_out_dir(runner.work_dir, runner.timestamp)

        for data_sample in outputs:
            self._draw_test_sample(data_sample)

    def on_predict_start(self, trainer: "pl.Trainer",
                         pl_module: "pl.LightningModule") -> None:
        """Prepare the output directory before prediction starts.

        ``test_out_dir`` (if given) is placed under
        ``trainer.default_root_dir`` and created.

        Args:
            trainer (pl.Trainer): The trainer running the prediction.
            pl_module (pl.LightningModule): The module being run (unused).
        """
        # NOTE(review): propagating ``trainer.datamodule.predict_dataset
        # .metainfo`` into ``self._visualizer.dataset_meta`` was disabled in
        # the original code; confirm whether it should be restored.
        self._resolve_test_out_dir(trainer.default_root_dir)

    def on_predict_batch_end(
            self,
            trainer: "pl.Trainer",
            pl_module: "pl.LightningModule",
            outputs: Any,
            batch: Any,
            batch_idx: int,
            dataloader_idx: int = 0,
    ) -> None:
        """Draw every sample of a prediction batch.

        Args:
            trainer (pl.Trainer): The trainer running the prediction (unused).
            pl_module (pl.LightningModule): The module being run (unused).
            outputs (Any): Iterable of data samples produced for the batch;
                each is expected to expose ``img_path`` and the predictions.
                ``None`` is treated as "nothing to draw".
            batch (Any): The input batch (unused).
            batch_idx (int): The index of the current batch (unused).
            dataloader_idx (int): The index of the dataloader (unused).
        """
        if not self.draw or outputs is None:
            return

        for data_sample in outputs:
            self._draw_test_sample(data_sample)