| | |
| | |
| |
|
| | """ |
| | This file contains components with some default boilerplate logic users may need |
| | in training / testing. They will not work for everyone, but many users may find them useful. |
| | |
| | The behavior of functions/classes in this file is subject to change, |
| | since they are meant to represent the "common default behavior" people need in their projects. |
| | """ |
| |
|
| | import argparse |
| | import logging |
| | import os |
| | import sys |
| | import weakref |
| | from collections import OrderedDict |
| | from typing import Optional |
| | import torch |
| | from fvcore.nn.precise_bn import get_bn_modules |
| | from omegaconf import OmegaConf |
| | from torch.nn.parallel import DistributedDataParallel |
| |
|
| | import detectron2.data.transforms as T |
| | from detectron2.checkpoint import DetectionCheckpointer |
| | from detectron2.config import CfgNode, LazyConfig |
| | from detectron2.data import ( |
| | MetadataCatalog, |
| | build_detection_test_loader, |
| | build_detection_train_loader, |
| | ) |
| | from detectron2.evaluation import ( |
| | DatasetEvaluator, |
| | inference_on_dataset, |
| | print_csv_format, |
| | verify_results, |
| | ) |
| | from detectron2.modeling import build_model |
| | from detectron2.solver import build_lr_scheduler, build_optimizer |
| | from detectron2.utils import comm |
| | from detectron2.utils.collect_env import collect_env_info |
| | from detectron2.utils.env import seed_all_rng |
| | from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter |
| | from detectron2.utils.file_io import PathManager |
| | from detectron2.utils.logger import setup_logger |
| |
|
| | from . import hooks |
| | from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase |
| |
|
# Names exported by `from <this module> import *`. The underscore-prefixed
# definitions in this file (e.g. `_try_get_key`, `_highlight`) are private helpers.
__all__ = [
    "create_ddp_model",
    "default_argument_parser",
    "default_setup",
    "default_writers",
    "DefaultPredictor",
    "DefaultTrainer",
]
| |
|
| |
|
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Wrap ``model`` in :class:`DistributedDataParallel` when more than one process is running.

    Args:
        model: a torch.nn.Module
        fp16_compression: if True, register the fp16 compression communication hook on
            the DDP wrapper. See
            https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :class:`torch.nn.parallel.DistributedDataParallel`.
            When ``device_ids`` is not supplied, it defaults to ``[comm.get_local_rank()]``.

    Returns:
        ``model`` itself for a single-process run, otherwise the DDP-wrapped model.
    """
    if comm.get_world_size() != 1:
        # Pin this process to its local device unless the caller picked devices explicitly.
        if "device_ids" not in kwargs:
            kwargs["device_ids"] = [comm.get_local_rank()]
        wrapped = DistributedDataParallel(model, **kwargs)
        if not fp16_compression:
            return wrapped

        # Imported lazily: only needed when compression is requested.
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        wrapped.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
        return wrapped
    return model
| |
|
| |
|
def default_argument_parser(epilog=None):
    """
    Build the command-line parser shared by detectron2 training / evaluation scripts.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage. When empty or
            None, a built-in usage example that mentions ``sys.argv[0]`` is used instead.

    Returns:
        argparse.ArgumentParser: a parser for ``--config-file``, ``--resume``,
        ``--eval-only``, ``--num-gpus``, ``--num-machines``, ``--machine-rank``,
        ``--dist-url``, plus every trailing token collected into ``opts``.
    """
    if not epilog:
        script = sys.argv[0]
        epilog = f"""
Examples:

Run on single machine:
$ {script} --num-gpus 8 --config-file cfg.yaml

Change some config options:
$ {script} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
(machine0)$ {script} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {script} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""
    parser = argparse.ArgumentParser(
        epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    add = parser.add_argument

    add("--config-file", default="", metavar="FILE", help="path to config file")
    add(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    add("--eval-only", action="store_true", help="perform evaluation only")
    add("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    add("--num-machines", type=int, default=1, help="total number of machines")
    add(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # The default port is derived deterministically from the user id (constant 1 on
    # Windows, which lacks os.getuid), so it is stable across runs of the same user and
    # always falls in [49152, 65535].
    user_key = 1 if sys.platform == "win32" else os.getuid()
    port = 2**15 + 2**14 + hash(user_key) % 2**14
    add(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )

    # Everything after the recognized flags is kept verbatim as config overrides.
    add(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
""".strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
| |
|
| |
|
def _try_get_key(cfg, *keys, default=None):
    """
    Return the value of the first key in ``keys`` that exists in ``cfg``; otherwise ``default``.

    Keys are dotted paths resolved with ``OmegaConf.select``. A yacs :class:`CfgNode`
    is first converted to an OmegaConf object through its YAML dump, so the same lookup
    works for both config systems.
    """
    if isinstance(cfg, CfgNode):
        cfg = OmegaConf.create(cfg.dump())
    # Private sentinel so a key that exists with value None still counts as "found".
    absent = object()
    for key in keys:
        value = OmegaConf.select(cfg, key, default=absent)
        if value is not absent:
            return value
    return default
| |
|
| |
|
| | def _highlight(code, filename): |
| | try: |
| | import pygments |
| | except ImportError: |
| | return code |
| |
|
| | from pygments.lexers import Python3Lexer, YamlLexer |
| | from pygments.formatters import Terminal256Formatter |
| |
|
| | lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer() |
| | code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai")) |
| | return code |
| |
|
| |
|
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory
    4. Seed the random number generators and (for training runs) set
       ``torch.backends.cudnn.benchmark`` from the config

    Args:
        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
    # Only the main process creates the directory and writes the config backup below.
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    # Configure the "fvcore" logger in addition to the default detectron2 logger.
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    # A missing, empty, or None config_file means there is nothing to echo.
    config_file = getattr(args, "config_file", "")
    if config_file:
        # Context manager so the file handle is closed as soon as it has been read.
        with PathManager.open(config_file, "r") as f:
            config_text = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                config_file,
                _highlight(config_text, config_file),
            )
        )

    if comm.is_main_process() and output_dir:
        # Back up the full config (yacs or LazyConfig) into the output directory.
        path = os.path.join(output_dir, "config.yaml")
        if isinstance(cfg, CfgNode):
            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
            with PathManager.open(path, "w") as f:
                f.write(cfg.dump())
        else:
            LazyConfig.save(cfg, path)
        logger.info("Full config saved to {}".format(path))

    # Offset a fixed seed by the rank so every worker gets a different but reproducible
    # seed. A negative or null seed passes None to seed_all_rng (presumably meaning
    # "pick a seed automatically" — see seed_all_rng).
    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
    seed_all_rng(None if seed is None or seed < 0 else seed + rank)

    # cudnn.benchmark is only configured for training runs, never for --eval-only runs
    # (presumably because autotuning overhead is not amortized over a short evaluation).
    if not getattr(args, "eval_only", False):
        torch.backends.cudnn.benchmark = _try_get_key(
            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
        )
| |
|
| |
|
def default_writers(output_dir: str, max_iter: Optional[int] = None):
    """
    Create the default :class:`EventWriter` list: a :class:`CommonMetricPrinter`,
    a :class:`JSONWriter` and a :class:`TensorboardXWriter`.

    ``output_dir`` is created first if it does not exist.

    Args:
        output_dir: directory that receives ``metrics.json`` and tensorboard event files
        max_iter: the total number of iterations, forwarded to :class:`CommonMetricPrinter`

    Returns:
        list[EventWriter]: the writers, in the order printer, JSON writer, tensorboard writer.
    """
    PathManager.mkdirs(output_dir)
    metrics_path = os.path.join(output_dir, "metrics.json")
    printer = CommonMetricPrinter(max_iter)
    json_writer = JSONWriter(metrics_path)
    tensorboard_writer = TensorboardXWriter(output_dir)
    return [printer, json_writer, tensorboard_writer]
| |
|
| |
|
class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config that runs on
    single device for a single input image.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and produce a single output, instead of a batch.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.
    If you'd like to do anything more complicated, please refer to its source code as
    examples to build and use the model manually.

    Attributes:
        metadata (Metadata or None): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST[0]; None when cfg.DATASETS.TEST is empty.

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        # Work on a private copy so the caller's config is never affected.
        self.cfg = cfg.clone()
        self.model = build_model(self.cfg)
        self.model.eval()
        # Always define the documented attribute, even without any test dataset.
        self.metadata = None
        if len(cfg.DATASETS.TEST):
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # MIN_SIZE_TEST is used for both ends of the short-edge size range, so the
        # test-time short-edge target is a single fixed value.
        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():
            # Input is always BGR; reverse the channel axis if the model expects RGB.
            if self.input_format == "RGB":
                original_image = original_image[:, :, ::-1]
            # The pre-resize size is passed to the model alongside the resized image
            # (presumably so outputs are reported at the original resolution).
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            # (H, W, C) array -> (C, H, W) float32 tensor.
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {"image": image, "height": height, "width": width}
            # The model takes a batch (list) of inputs; unwrap the single result.
            predictions = self.model([inputs])[0]
            return predictions
| |
|
| |
|
class DefaultTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
       defined by the given config. Create a LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    See the :doc:`/tutorials/training` tutorials for more details.

    Note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Examples:
    ::
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()

    Attributes:
        scheduler: the LR scheduler returned by :meth:`build_lr_scheduler`.
        checkpointer (DetectionCheckpointer): saves to / loads from ``cfg.OUTPUT_DIR``.
        cfg (CfgNode): the (possibly auto-scaled, see :meth:`auto_scale_workers`) config.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): the training config. It may be replaced by a rescaled copy
                when ``cfg.SOLVER.REFERENCE_WORLD_SIZE`` differs from the world size.
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        # INFO not enabled on the detectron2 logger suggests setup_logger() was never called.
        if not logger.isEnabledFor(logging.INFO):
            setup_logger()
        # Resolved through self (like the build_* methods) so subclasses can override it.
        cfg = self.auto_scale_workers(cfg, comm.get_world_size())

        # Construction order: model, optimizer (needs the model), then the data loader.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Checkpoints are saved to OUTPUT_DIR, alongside logs and metrics.
            model,
            cfg.OUTPUT_DIR,
            # A weak proxy, so the checkpointer does not keep this trainer alive.
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
        a `last_checkpoint` file), resume from the file. Resuming means loading all
        available states (eg. optimizer and scheduler) and update iteration counter
        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.

        Otherwise, this is considered as an independent training. The method will load model
        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
        from iteration 0.

        Args:
            resume (bool): whether to do resume or not
        """
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            # self.iter is presumably restored from the checkpoint's trainer state and
            # names an iteration that already completed, so continue with the next one.
            self.start_iter = self.iter + 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]: may contain None (when PreciseBN is disabled or the model
            has no BN modules).
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        # The extra loader built for PreciseBN below uses no worker processes.
        cfg.DATALOADER.NUM_WORKERS = 0

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Runs at the same period as evaluation, and is registered before it.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # A separate data loader, so the training loader is not consumed.
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # The checkpointer comes after PreciseBN so that the updated BN statistics are
        # included in the saved checkpoints.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # Keep the latest results so train() can verify / return them.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Evaluation comes after the checkpointer, so a checkpoint exists for debugging
        # if evaluation fails.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Writers go last so that evaluation metrics from this step are written too.
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used using :func:`default_writers()`.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)

    def train(self):
        """
        Run training.

        Returns:
            The results of the most recent evaluation, if evaluation ran during training.
            Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            verify_results(self.cfg, self._last_eval_results)
        # _last_eval_results is only set by the EvalHook; return None instead of
        # raising AttributeError when no evaluation happened.
        return getattr(self, "_last_eval_results", None)

    def run_step(self):
        """Run one iteration by delegating to the wrapped SimpleTrainer / AMPTrainer."""
        # Keep the inner trainer's iteration counter in sync with this trainer's.
        self._trainer.iter = self.iter
        self._trainer.run_step()

    def state_dict(self):
        """Return this trainer's state, with the inner trainer's state under ``"_trainer"``."""
        ret = super().state_dict()
        ret["_trainer"] = self._trainer.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`, including the inner trainer's."""
        super().load_state_dict(state_dict)
        self._trainer.load_state_dict(state_dict["_trainer"])

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator or None

        It is not implemented by default.

        Raises:
            NotImplementedError: always, unless overridden in a subclass.
        """
        raise NotImplementedError(
            """
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics, keyed by dataset name. When exactly one
            test dataset is configured, that dataset's results are returned directly.
        """
        logger = logging.getLogger(__name__)
        # Accept a single evaluator as shorthand for a one-element list.
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # Explicitly passed evaluators take precedence over build_evaluator().
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = next(iter(results.values()))
        return results

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        When the config is defined for certain number of workers (according to
        ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
        workers currently in use, returns a new cfg where the total batch size
        is scaled so that the per-GPU batch size stays the same as the
        original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.

        Other config options are also scaled accordingly:
        * training steps and warmup steps are scaled inverse proportionally.
        * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.

        For example, with the original config like the following:

        .. code-block:: yaml

            IMS_PER_BATCH: 16
            BASE_LR: 0.1
            REFERENCE_WORLD_SIZE: 8
            MAX_ITER: 5000
            STEPS: (4000,)
            CHECKPOINT_PERIOD: 1000

        When this config is used on 16 GPUs instead of the reference number 8,
        calling this method will return a new config with:

        .. code-block:: yaml

            IMS_PER_BATCH: 32
            BASE_LR: 0.2
            REFERENCE_WORLD_SIZE: 16
            MAX_ITER: 2500
            STEPS: (2000,)
            CHECKPOINT_PERIOD: 500

        Note that both the original config and this new config can be trained on 16 GPUs.
        It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).

        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        # 0 disables auto-scaling; an unchanged world size needs no scaling either.
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = cfg.clone()
        # Remember the frozen state so the returned copy matches the input.
        frozen = cfg.is_frozen()
        cfg.defrost()

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )

        if frozen:
            cfg.freeze()
        return cfg
| |
|
| |
|
| | |
# Expose the wrapped trainer's model / data_loader / optimizer as read-write attributes
# of DefaultTrainer that forward to self._trainer. The name is bound as a lambda default
# argument (x=_attr) so each property keeps its own name instead of the loop's last value.
for _attr in ("model", "data_loader", "optimizer"):
    setattr(
        DefaultTrainer,
        _attr,
        property(
            fget=lambda self, x=_attr: getattr(self._trainer, x),
            fset=lambda self, value, x=_attr: setattr(self._trainer, x, value),
        ),
    )
| |
|