| |
| import warnings |
|
|
| import mmcv |
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| from mmcv.parallel import MMDataParallel, MMDistributedDataParallel |
| from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, |
| Fp16OptimizerHook, OptimizerHook, build_optimizer, |
| build_runner, get_dist_info) |
| from mmdet.core import DistEvalHook, EvalHook |
| from mmdet.datasets import build_dataloader, build_dataset |
|
|
| from mmocr import digit_version |
| from mmocr.apis.utils import (disable_text_recog_aug_test, |
| replace_image_to_tensor) |
| from mmocr.utils import get_root_logger |
|
|
|
|
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build dataloaders, runner and hooks from ``cfg`` and run training.

    Args:
        model: The model to train. It is wrapped in
            ``MMDistributedDataParallel`` when ``distributed`` is true and
            in ``MMDataParallel`` otherwise.
        dataset: A dataset or a list/tuple of datasets; one dataloader is
            built per dataset.
        cfg: The config object. Read both via attribute access and
            ``.get()``; ``cfg.runner`` is filled in when it is missing.
        distributed (bool): Whether to set up distributed training.
        validate (bool): Whether to register an evaluation hook built from
            ``cfg.data.val``.
        timestamp: Value stored on ``runner.timestamp``.
        meta: Passed through to the runner as ``meta``.
    """
    logger = get_root_logger(cfg.log_level)

    # Normalise to a sequence so one loader can be built per dataset.
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]

    # Shared loader settings, lowest priority first: built-in defaults,
    # parrots-specific defaults, then whatever ``cfg.data`` sets explicitly.
    loader_cfg = dict(
        seed=cfg.get('seed'),
        drop_last=False,
        dist=distributed,
        num_gpus=len(cfg.gpu_ids))
    if torch.__version__ == 'parrots':
        loader_cfg.update(prefetch_num=2, pin_memory=False)
    overridable_keys = (
        'samples_per_gpu',
        'workers_per_gpu',
        'shuffle',
        'seed',
        'drop_last',
        'prefetch_num',
        'pin_memory',
        'persistent_workers',
    )
    for key in overridable_keys:
        if key in cfg.data:
            loader_cfg[key] = cfg.data[key]

    # Entries in ``cfg.data.train_dataloader`` win over the shared settings.
    train_loader_cfg = dict(loader_cfg)
    train_loader_cfg.update(cfg.data.get('train_dataloader', {}))

    data_loaders = []
    for ds in dataset:
        data_loaders.append(build_dataloader(ds, **train_loader_cfg))

    # Wrap the model for (distributed) data-parallel execution.
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        cuda_missing = not torch.cuda.is_available()
        if cuda_missing:
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)

    optimizer = build_optimizer(model, cfg.optimizer)

    # Configs without a ``runner`` section fall back to an epoch-based
    # runner driven by ``total_epochs``.
    if 'runner' in cfg:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    else:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner_defaults = dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(cfg.runner, default_args=runner_defaults)

    runner.timestamp = timestamp

    # Choose the optimizer hook: fp16 settings take precedence, then the
    # plain hook for distributed runs whose config names no hook type,
    # otherwise the raw config is handed to the runner.
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed and isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())

    if validate:
        val_overrides = cfg.data.get('val_dataloader', {})
        fallback_samples = cfg.data.get('samples_per_gpu', 1)
        val_samples_per_gpu = val_overrides.get('samples_per_gpu',
                                                fallback_samples)
        if val_samples_per_gpu > 1:
            # Batched validation: rewrite the config through the two
            # helpers below before the validation dataset is built.
            cfg = disable_text_recog_aug_test(cfg)
            cfg = replace_image_to_tensor(cfg)

        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        # Later updates win. ``val_dataloader`` is re-read here because
        # ``cfg`` may have been replaced above.
        val_loader_cfg = dict(loader_cfg)
        val_loader_cfg.update(shuffle=False, drop_last=False)
        val_loader_cfg.update(cfg.data.get('val_dataloader', {}))
        val_loader_cfg.update(samples_per_gpu=val_samples_per_gpu)

        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        if distributed:
            eval_hook = DistEvalHook
        else:
            eval_hook = EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # Resuming takes priority over merely loading weights.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
|
|
|
|
def init_random_seed(seed=None, device='cuda'):
    """Return the seed to use, generating and sharing one when none is given.

    An explicitly supplied seed is returned unchanged. Otherwise a random
    seed is drawn; when more than one process is running, the value drawn
    on rank 0 is broadcast so every process returns the same number.

    Args:
        seed (int, Optional): The seed. ``None`` requests a random one.
        device (str): Device on which the tensor used for the broadcast
            is created.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        rank, world_size = get_dist_info()
        # Every rank draws a number; only rank 0's value survives the
        # broadcast below.
        seed = np.random.randint(2**31)
        if world_size != 1:
            payload = seed if rank == 0 else 0
            shared = torch.tensor(payload, dtype=torch.int32, device=device)
            dist.broadcast(shared, src=0)
            seed = shared.item()
    return seed
|
|