| |
| |
| """ |
| A script to benchmark builtin models. |
| |
| Note: this script has an extra dependency of psutil. |
| """ |
|
|
| import itertools |
| import logging |
| import psutil |
| import torch |
| import tqdm |
| from fvcore.common.timer import Timer |
| from torch.nn.parallel import DistributedDataParallel |
|
|
| from detectron2.checkpoint import DetectionCheckpointer |
| from detectron2.config import get_cfg |
| from detectron2.data import ( |
| DatasetFromList, |
| build_detection_test_loader, |
| build_detection_train_loader, |
| ) |
| from detectron2.engine import SimpleTrainer, default_argument_parser, hooks, launch |
| from detectron2.modeling import build_model |
| from detectron2.solver import build_optimizer |
| from detectron2.utils import comm |
| from detectron2.utils.events import CommonMetricPrinter |
| from detectron2.utils.logger import setup_logger |
|
|
| logger = logging.getLogger("detectron2") |
|
|
|
|
def setup(args):
    """Build a frozen detectron2 config from the command-line arguments.

    Args:
        args: namespace produced by ``default_argument_parser``; this reads
            ``args.config_file`` and ``args.opts``.

    Returns:
        CfgNode: the merged, frozen config.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    # Pin a small base LR for the benchmark run — presumably to keep the short
    # training loop numerically stable; any --opts override below still wins.
    cfg.SOLVER.BASE_LR = 0.001
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    # One logger per rank so distributed runs stay readable.
    setup_logger(distributed_rank=comm.get_rank())
    return cfg
|
|
|
|
def benchmark_data(args):
    """Benchmark the training dataloader: startup latency, throughput, RAM use."""
    cfg = setup(args)

    init_timer = Timer()
    loader = build_detection_train_loader(cfg)
    logger.info("Initialize loader using {} seconds.".format(init_timer.seconds()))

    # Warmup: the first batch carries worker-process startup cost, which is
    # reported separately from steady-state throughput.
    init_timer.reset()
    batch_iter = iter(loader)
    startup_time = None
    for warmup_idx in range(10):
        next(batch_iter)
        if warmup_idx == 0:
            startup_time = init_timer.seconds()

    def _timed_run():
        # Time a fixed number of iterations; log images/sec context alongside.
        bench_timer = Timer()
        num_iter = 1000
        for _ in tqdm.trange(num_iter):
            next(batch_iter)
        logger.info(
            "{} iters ({} images) in {} seconds.".format(
                num_iter, num_iter * cfg.SOLVER.IMS_PER_BATCH, bench_timer.seconds()
            )
        )

    _timed_run()
    logger.info("Startup time: {} seconds".format(startup_time))
    mem = psutil.virtual_memory()
    logger.info(
        "RAM Usage: {:.2f}/{:.2f} GB".format(
            (mem.total - mem.available) / 1024 ** 3, mem.total / 1024 ** 3
        )
    )

    # Repeat several times to see whether throughput stays stable over time.
    for _ in range(10):
        _timed_run()
|
|
|
|
def benchmark_train(args):
    """Benchmark the full training loop (forward + backward + optimizer step)."""
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    optimizer = build_optimizer(cfg, model)
    checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    # Cache a small number of batches once, then replay them forever, so the
    # measurement reflects compute rather than data loading.
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    cached_batches = list(itertools.islice(build_detection_train_loader(cfg), 100))

    def _replay_forever():
        dataset = DatasetFromList(cached_batches, copy=False)
        while True:
            yield from dataset

    max_iter = 400
    trainer = SimpleTrainer(model, _replay_forever(), optimizer)
    trainer.register_hooks(
        [hooks.IterationTimer(), hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])]
    )
    trainer.train(1, max_iter)
|
|
|
|
@torch.no_grad()
def benchmark_eval(args):
    """Benchmark inference: run the model on cached inputs for max_iter steps."""
    cfg = setup(args)
    model = build_model(cfg)
    model.eval()
    logger.info("Model:\n{}".format(model))
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

    # Cache inputs up front so data loading does not affect the timing.
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    cached_batches = list(itertools.islice(loader, 100))

    def _replay_forever():
        while True:
            yield from DatasetFromList(cached_batches, copy=False)

    # Warmup runs (lazy initialization, autotuning, etc.) before timing starts.
    for _ in range(5):
        model(cached_batches[0])

    max_iter = 400
    timer = Timer()
    with tqdm.tqdm(total=max_iter) as pbar:
        for step, inputs in enumerate(_replay_forever()):
            if step == max_iter:
                break
            model(inputs)
            pbar.update()
    logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds()))
|
|
|
|
if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--task", choices=["train", "eval", "data"], required=True)
    args = parser.parse_args()
    # This script drives its own loop; the generic --eval-only flag is unsupported.
    assert not args.eval_only

    # Note: training speed may not be representative. The training cost of an
    # R-CNN model varies with the content of the data and the quality of the model.
    task_to_benchmark = {
        "data": benchmark_data,
        "train": benchmark_train,
        "eval": benchmark_eval,
    }
    f = task_to_benchmark[args.task]

    assert args.num_gpus == 1 and args.num_machines == 1
    launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,))
|
|