Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import logging | |
| import os | |
| import os.path as osp | |
| import platform | |
| import random | |
| import string | |
| import tempfile | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.parallel import MMDataParallel | |
| from mmcv.runner import (RUNNERS, EpochBasedRunner, IterBasedRunner, | |
| build_runner) | |
| from mmcv.runner.hooks import IterTimerHook | |
| class OldStyleModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.conv = nn.Conv2d(3, 3, 1) | |
| class Model(OldStyleModel): | |
| def train_step(self): | |
| pass | |
| def val_step(self): | |
| pass | |
| def test_build_runner(): | |
| temp_root = tempfile.gettempdir() | |
| dir_name = ''.join( | |
| [random.choice(string.ascii_letters) for _ in range(10)]) | |
| default_args = dict( | |
| model=Model(), | |
| work_dir=osp.join(temp_root, dir_name), | |
| logger=logging.getLogger()) | |
| cfg = dict(type='EpochBasedRunner', max_epochs=1) | |
| runner = build_runner(cfg, default_args=default_args) | |
| assert runner._max_epochs == 1 | |
| cfg = dict(type='IterBasedRunner', max_iters=1) | |
| runner = build_runner(cfg, default_args=default_args) | |
| assert runner._max_iters == 1 | |
| with pytest.raises(ValueError, match='Only one of'): | |
| cfg = dict(type='IterBasedRunner', max_epochs=1, max_iters=1) | |
| runner = build_runner(cfg, default_args=default_args) | |
| def test_epoch_based_runner(runner_class): | |
| with pytest.warns(DeprecationWarning): | |
| # batch_processor is deprecated | |
| model = OldStyleModel() | |
| def batch_processor(): | |
| pass | |
| _ = runner_class(model, batch_processor, logger=logging.getLogger()) | |
| with pytest.raises(TypeError): | |
| # batch_processor must be callable | |
| model = OldStyleModel() | |
| _ = runner_class(model, batch_processor=0, logger=logging.getLogger()) | |
| with pytest.raises(TypeError): | |
| # optimizer must be a optimizer or a dict of optimizers | |
| model = Model() | |
| optimizer = 'NotAOptimizer' | |
| _ = runner_class( | |
| model, optimizer=optimizer, logger=logging.getLogger()) | |
| with pytest.raises(TypeError): | |
| # optimizer must be a optimizer or a dict of optimizers | |
| model = Model() | |
| optimizers = dict(optim1=torch.optim.Adam(), optim2='NotAOptimizer') | |
| _ = runner_class( | |
| model, optimizer=optimizers, logger=logging.getLogger()) | |
| with pytest.raises(TypeError): | |
| # logger must be a logging.Logger | |
| model = Model() | |
| _ = runner_class(model, logger=None) | |
| with pytest.raises(TypeError): | |
| # meta must be a dict or None | |
| model = Model() | |
| _ = runner_class(model, logger=logging.getLogger(), meta=['list']) | |
| with pytest.raises(AssertionError): | |
| # model must implement the method train_step() | |
| model = OldStyleModel() | |
| _ = runner_class(model, logger=logging.getLogger()) | |
| with pytest.raises(TypeError): | |
| # work_dir must be a str or None | |
| model = Model() | |
| _ = runner_class(model, work_dir=1, logger=logging.getLogger()) | |
| with pytest.raises(RuntimeError): | |
| # batch_processor and train_step() cannot be both set | |
| def batch_processor(): | |
| pass | |
| model = Model() | |
| _ = runner_class(model, batch_processor, logger=logging.getLogger()) | |
| # test work_dir | |
| model = Model() | |
| temp_root = tempfile.gettempdir() | |
| dir_name = ''.join( | |
| [random.choice(string.ascii_letters) for _ in range(10)]) | |
| work_dir = osp.join(temp_root, dir_name) | |
| _ = runner_class(model, work_dir=work_dir, logger=logging.getLogger()) | |
| assert osp.isdir(work_dir) | |
| _ = runner_class(model, work_dir=work_dir, logger=logging.getLogger()) | |
| assert osp.isdir(work_dir) | |
| os.removedirs(work_dir) | |
| def test_runner_with_parallel(runner_class): | |
| def batch_processor(): | |
| pass | |
| model = MMDataParallel(OldStyleModel()) | |
| _ = runner_class(model, batch_processor, logger=logging.getLogger()) | |
| model = MMDataParallel(Model()) | |
| _ = runner_class(model, logger=logging.getLogger()) | |
| with pytest.raises(RuntimeError): | |
| # batch_processor and train_step() cannot be both set | |
| def batch_processor(): | |
| pass | |
| model = MMDataParallel(Model()) | |
| _ = runner_class(model, batch_processor, logger=logging.getLogger()) | |
| def test_save_checkpoint(runner_class): | |
| model = Model() | |
| runner = runner_class(model=model, logger=logging.getLogger()) | |
| with pytest.raises(TypeError): | |
| # meta should be None or dict | |
| runner.save_checkpoint('.', meta=list()) | |
| with tempfile.TemporaryDirectory() as root: | |
| runner.save_checkpoint(root) | |
| latest_path = osp.join(root, 'latest.pth') | |
| assert osp.exists(latest_path) | |
| if isinstance(runner, EpochBasedRunner): | |
| first_ckp_path = osp.join(root, 'epoch_1.pth') | |
| elif isinstance(runner, IterBasedRunner): | |
| first_ckp_path = osp.join(root, 'iter_1.pth') | |
| assert osp.exists(first_ckp_path) | |
| if platform.system() != 'Windows': | |
| assert osp.realpath(latest_path) == osp.realpath(first_ckp_path) | |
| else: | |
| # use copy instead of symlink on windows | |
| pass | |
| torch.load(latest_path) | |
| def test_build_lr_momentum_hook(runner_class): | |
| model = Model() | |
| runner = runner_class(model=model, logger=logging.getLogger()) | |
| # test policy that is already title | |
| lr_config = dict( | |
| policy='CosineAnnealing', | |
| by_epoch=False, | |
| min_lr_ratio=0, | |
| warmup_iters=2, | |
| warmup_ratio=0.9) | |
| runner.register_lr_hook(lr_config) | |
| assert len(runner.hooks) == 1 | |
| # test policy that is already title | |
| lr_config = dict( | |
| policy='Cyclic', | |
| by_epoch=False, | |
| target_ratio=(10, 1), | |
| cyclic_times=1, | |
| step_ratio_up=0.4) | |
| runner.register_lr_hook(lr_config) | |
| assert len(runner.hooks) == 2 | |
| # test policy that is not title | |
| lr_config = dict( | |
| policy='cyclic', | |
| by_epoch=False, | |
| target_ratio=(0.85 / 0.95, 1), | |
| cyclic_times=1, | |
| step_ratio_up=0.4) | |
| runner.register_lr_hook(lr_config) | |
| assert len(runner.hooks) == 3 | |
| # test policy that is title | |
| lr_config = dict( | |
| policy='Step', | |
| warmup='linear', | |
| warmup_iters=500, | |
| warmup_ratio=1.0 / 3, | |
| step=[8, 11]) | |
| runner.register_lr_hook(lr_config) | |
| assert len(runner.hooks) == 4 | |
| # test policy that is not title | |
| lr_config = dict( | |
| policy='step', | |
| warmup='linear', | |
| warmup_iters=500, | |
| warmup_ratio=1.0 / 3, | |
| step=[8, 11]) | |
| runner.register_lr_hook(lr_config) | |
| assert len(runner.hooks) == 5 | |
| # test policy that is already title | |
| mom_config = dict( | |
| policy='CosineAnnealing', | |
| min_momentum_ratio=0.99 / 0.95, | |
| by_epoch=False, | |
| warmup_iters=2, | |
| warmup_ratio=0.9 / 0.95) | |
| runner.register_momentum_hook(mom_config) | |
| assert len(runner.hooks) == 6 | |
| # test policy that is already title | |
| mom_config = dict( | |
| policy='Cyclic', | |
| by_epoch=False, | |
| target_ratio=(0.85 / 0.95, 1), | |
| cyclic_times=1, | |
| step_ratio_up=0.4) | |
| runner.register_momentum_hook(mom_config) | |
| assert len(runner.hooks) == 7 | |
| # test policy that is already title | |
| mom_config = dict( | |
| policy='cyclic', | |
| by_epoch=False, | |
| target_ratio=(0.85 / 0.95, 1), | |
| cyclic_times=1, | |
| step_ratio_up=0.4) | |
| runner.register_momentum_hook(mom_config) | |
| assert len(runner.hooks) == 8 | |
| def test_register_timer_hook(runner_class): | |
| model = Model() | |
| runner = runner_class(model=model, logger=logging.getLogger()) | |
| # test register None | |
| timer_config = None | |
| runner.register_timer_hook(timer_config) | |
| assert len(runner.hooks) == 0 | |
| # test register IterTimerHook with config | |
| timer_config = dict(type='IterTimerHook') | |
| runner.register_timer_hook(timer_config) | |
| assert len(runner.hooks) == 1 | |
| assert isinstance(runner.hooks[0], IterTimerHook) | |
| # test register IterTimerHook | |
| timer_config = IterTimerHook() | |
| runner.register_timer_hook(timer_config) | |
| assert len(runner.hooks) == 2 | |
| assert isinstance(runner.hooks[1], IterTimerHook) | |