| import random |
| import warnings |
| import os |
| import numpy as np |
| import torch |
| |
| import torch.distributed as dist |
| from mmcv.parallel import MMDataParallel, MMDistributedDataParallel |
| from mmcv.runner import ( |
| HOOKS, |
| DistSamplerSeedHook, |
| EpochBasedRunner, |
| Fp16OptimizerHook, |
| OptimizerHook, |
| build_optimizer, |
| build_runner, |
| get_dist_info, |
| ) |
| from mmcv.utils import build_from_cfg |
|
|
| from mmdet.core import EvalHook, DistEvalHook |
|
|
| from mmdet.datasets import build_dataset, replace_ImageToTensor |
| from mmdet.utils import get_root_logger |
| import time |
| import os.path as osp |
| from mmdet3d_plugin.datasets.builder import build_dataloader |
| from mmdet3d_plugin.datasets.builder import custom_build_dataset |
|
|
| |
| from mmdet3d_plugin.models.runner.epoch_based import EpochBasedRunnerAutoResume |
|
|
|
|
def custom_train_detector(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    timestamp=None,
    eval_model=None,
    meta=None,
):
    """Build data loaders, runner and hooks from ``cfg`` and start training.

    Args:
        model: Detector to train. It is moved to GPU and wrapped in
            ``MMDistributedDataParallel`` or ``MMDataParallel``.
        dataset: A single training dataset or a list/tuple of them; one
            data loader is built per dataset.
        cfg: Experiment config. Read with both attribute and ``get``/``in``
            access (presumably an mmcv ``Config`` -- confirm with caller).
            It is mutated in place: ``data.samples_per_gpu``, ``runner``,
            ``data.val`` and ``evaluation`` may be written.
        distributed: Whether to use distributed data parallel wrappers,
            samplers and eval hook.
        validate: Whether to register an evaluation hook on ``cfg.data.val``.
        timestamp: Stored on the runner as ``runner.timestamp``.
        eval_model: Optional second model; when given it is wrapped like
            ``model`` and passed to the runner as ``eval_model``.
        meta: Passed through to the runner unchanged.

    Raises:
        AssertionError: If ``cfg.total_epochs`` disagrees with
            ``cfg.runner.max_epochs``, if validation is requested with
            ``samples_per_gpu > 1``, or if ``cfg.custom_hooks`` is not a
            list of dicts.
    """
    logger = get_root_logger(cfg.log_level)

    # ---- data loaders ----------------------------------------------------
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    if "imgs_per_gpu" in cfg.data:
        logger.warning(
            '"imgs_per_gpu" is deprecated in MMDet V2.0. '
            'Please use "samples_per_gpu" instead'
        )
        if "samples_per_gpu" in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f"={cfg.data.imgs_per_gpu} is used in this experiments"
            )
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f"{cfg.data.imgs_per_gpu} in this experiments"
            )
        # The deprecated key always wins when it is present.
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            shuffler_sampler=cfg.data.shuffler_sampler,
            nonshuffler_sampler=cfg.data.nonshuffler_sampler,
        )
        for ds in dataset
    ]
    logger.info("dataloader build done.")

    # ---- put model(s) on GPU ---------------------------------------------
    if distributed:
        find_unused_parameters = cfg.get("find_unused_parameters", False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters,
        )
        if eval_model is not None:
            eval_model = MMDistributedDataParallel(
                eval_model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters,
            )
    else:
        model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
        if eval_model is not None:
            eval_model = MMDataParallel(
                eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids
            )
    logger.info("DDP/DP model build done.")

    # ---- runner ----------------------------------------------------------
    optimizer = build_optimizer(model, cfg.optimizer)

    if "runner" not in cfg:
        # Legacy configs without a `runner` section fall back to total_epochs.
        cfg.runner = {"type": "EpochBasedRunnerAutoResume", "max_epochs": cfg.total_epochs}
        warnings.warn(
            "config is now expected to have a `runner` section, "
            "please set `runner` in your config.",
            UserWarning,
        )
    else:
        if "total_epochs" in cfg:
            logger.info(f"cfg.total_epochs: {cfg.total_epochs}, cfg.runner.max_epochs: {cfg.runner.max_epochs}")
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner_args = dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta,
    )
    # Only forward `eval_model` when one was supplied, so runners that do not
    # accept that argument keep working.
    if eval_model is not None:
        runner_args["eval_model"] = eval_model
    runner = build_runner(cfg.runner, default_args=runner_args)

    runner.timestamp = timestamp

    # ---- optimizer hook (fp16 / distributed / plain config) --------------
    fp16_cfg = cfg.get("fp16", None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed
        )
    elif distributed and "type" not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # ---- standard training hooks -----------------------------------------
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get("momentum_config", None),
    )

    if distributed:
        if isinstance(runner, EpochBasedRunnerAutoResume):
            runner.register_hook(DistSamplerSeedHook())

    # ---- evaluation hook -------------------------------------------------
    if validate:
        # Note: pop() removes the key from cfg.data.val before the dataset is built.
        val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1)
        if val_samples_per_gpu > 1:
            # Previously a bare `assert False`, which is stripped under
            # `python -O` and gave no hint about the cause. The exception
            # type is kept for backward compatibility.
            raise AssertionError(
                "validation with samples_per_gpu > 1 is not supported; "
                f"got samples_per_gpu={val_samples_per_gpu}"
            )
        val_dataset = custom_build_dataset(cfg.data.val, dict(test_mode=True))

        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            shuffler_sampler=cfg.data.shuffler_sampler,
            nonshuffler_sampler=cfg.data.nonshuffler_sampler,
        )
        eval_cfg = cfg.get("evaluation", {})
        eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner"
        # Results go to <work_dir>/val/<current time>, with spaces and colons
        # replaced so the name is filesystem-friendly.
        eval_cfg["jsonfile_prefix"] = osp.join(
            cfg.work_dir, "val", time.ctime().replace(" ", "_").replace(":", "_")
        )
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # ---- user-defined hooks ----------------------------------------------
    if cfg.get("custom_hooks", None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(
            custom_hooks, list
        ), f"custom_hooks expect list type, but got {type(custom_hooks)}"
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), (
                "Each item in custom_hooks expects dict type, but got "
                f"{type(hook_cfg)}"
            )
            # Work on a copy so the config entry itself is not modified.
            hook_cfg = hook_cfg.copy()
            # Overrides any `out_dir` given in the config.
            # NOTE(review): presumably every custom hook must accept an
            # `out_dir` argument for this to work -- confirm.
            hook_cfg['out_dir'] = cfg.work_dir
            priority = hook_cfg.pop("priority", "NORMAL")
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    # ---- checkpoint handling ---------------------------------------------
    # Resume only when the checkpoint path exists on disk; a missing
    # `resume_from` path falls through to `load_from`.
    if cfg.resume_from and os.path.exists(cfg.resume_from):
        logger.info(f'resume_from: {cfg.resume_from}')
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        logger.info(f'load from {cfg.load_from}')
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow)
|
|