| | |
| | |
| |
|
| | import os |
| | import sys |
| | import random |
| | import time |
| | import warnings |
| | from loguru import logger |
| |
|
| | import torch |
| | import torch.backends.cudnn as cudnn |
| |
|
| | from yolox.exp import Exp, get_exp |
| | from yolox.core import Trainer |
| | from yolox.utils import configure_module, configure_omp |
| | from yolox.tools.train import make_parser |
| |
|
| |
|
class AssignVisualizer(Trainer):
    """Trainer subclass that visualizes label assignment instead of training.

    It reuses the ``Trainer`` data pipeline (prefetcher, preprocessing, AMP
    context), but each iteration only calls the model's ``visualize`` method
    on one preprocessed batch. Output files are written under
    ``<self.file_name>/vis``. The process exits once ``args.max_batch``
    batches have been visualized.
    """

    def __init__(self, exp: Exp, args):
        super().__init__(exp, args)
        # Number of batches visualized so far; also used in output file names.
        self.batch_cnt = 0
        # Visualizations go into a "vis" folder inside the experiment output dir.
        self.vis_dir = os.path.join(self.file_name, "vis")
        os.makedirs(self.vis_dir, exist_ok=True)

    def train_one_iter(self):
        """Visualize the assignment for a single batch in place of a training step.

        Raises:
            SystemExit: once ``self.args.max_batch`` batches have been processed.
        """
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
        data_end_time = time.time()

        # Custom methods such as ``visualize`` live on the wrapped module when
        # the trainer wraps the model (presumably DistributedDataParallel in
        # multi-GPU runs -- confirm against Trainer.before_train). A wrapper's
        # attribute lookup only resolves parameters/buffers/submodules, so
        # calling ``visualize`` on it would raise AttributeError.
        model = getattr(self.model, "module", self.model)

        with torch.cuda.amp.autocast(enabled=self.amp_training):
            path_prefix = os.path.join(self.vis_dir, f"assign_vis_{self.batch_cnt}_")
            model.visualize(inps, targets, path_prefix)

        if self.use_model_ema:
            self.ema_model.update(self.model)

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
        )
        self.batch_cnt += 1
        if self.batch_cnt >= self.args.max_batch:
            # SystemExit unwinds out of the training loop.
            # NOTE(review): assumes the base Trainer.train() still runs
            # after_train() in a ``finally`` -- confirm.
            sys.exit(0)

    def after_train(self):
        """Log completion instead of running the base trainer's post-training work."""
        logger.info("Finish visualize assignment, exit...")
| |
|
| |
|
def assign_vis_parser():
    """Return the CLI parser for the assignment visualizer.

    Starts from the standard training parser and adds ``--max-batch``, the
    number of batches to visualize before exiting.
    """
    arg_parser = make_parser()
    arg_parser.add_argument(
        "--max-batch",
        type=int,
        default=1,
        help="max batch of images to visualize",
    )
    return arg_parser
| |
|
| |
|
@logger.catch
def main(exp: Exp, args):
    """Configure the runtime and run the assignment visualizer.

    Args:
        exp: experiment description; ``exp.seed`` (when not None) requests
            seeded, reproducible behavior.
        args: parsed command-line arguments (see ``assign_vis_parser``).
    """
    if exp.seed is not None:
        random.seed(exp.seed)
        torch.manual_seed(exp.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    configure_omp()
    # cuDNN benchmarking times candidate convolution algorithms and picks one,
    # so the selected algorithm can differ between runs. Leave it off when a
    # seed was requested, otherwise the determinism set up above is undermined.
    cudnn.benchmark = exp.seed is None

    visualizer = AssignVisualizer(exp, args)
    visualizer.train()
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: build the experiment from CLI args, then run main().
    configure_module()
    args = assign_vis_parser().parse_args()

    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    # Fall back to the experiment's own name when none was given on the CLI.
    args.experiment_name = args.experiment_name or exp.exp_name

    main(exp, args)
| |
|