# /*--------------------------------------------------------------------------------------------- #  * Copyright (c) 2025 STMicroelectronics. #  * All rights reserved. #  * #  * This software is licensed under terms that can be found in the LICENSE file in #  * the root directory of this software component. #  * If no LICENSE file comes with this software, it is provided AS-IS. #  *--------------------------------------------------------------------------------------------*/ import logging import os import sys import time from collections import OrderedDict from contextlib import suppress from datetime import datetime from functools import partial from types import SimpleNamespace import torch import torch.nn as nn import torchvision.utils import yaml from timm import utils from timm.data import (FastCollateMixup, Mixup, resolve_data_config) from timm.layers import (convert_splitbn_model, convert_sync_batchnorm, set_fast_norm) from timm.loss import (BinaryCrossEntropy, JsdCrossEntropy, LabelSmoothingCrossEntropy, SoftTargetCrossEntropy) from timm.models import (load_checkpoint, model_parameters, resume_checkpoint, safe_model_name) from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler from torch.nn.parallel import DistributedDataParallel as NativeDDP from torch.utils.tensorboard import SummaryWriter SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.dirname(SCRIPT_DIR)) from common.model_utils.checkpoint_saver import CheckpointSaver from common.onnx_utils.onnx_model_convertor import torch_model_export_static try: from apex import amp from apex.parallel import DistributedDataParallel as ApexDDP from apex.parallel import convert_syncbn_model has_apex = True except ImportError: has_apex = False has_native_amp = False try: if getattr(torch.cuda.amp, 'autocast') is not None: has_native_amp = True except AttributeError: pass try: import wandb has_wandb = True except ImportError: has_wandb = False try: from functorch.compile import memory_efficient_fusion has_functorch = True except ImportError as e: has_functorch = False has_compile = hasattr(torch, 'compile') _logger = logging.getLogger('train') class ICTrainer: """ Image Classification Trainer. This class encapsulates the training and validation logic for image classification models. It handles model initialization, data loading, training loop management, evaluation, and logging based on a provided configuration. Attributes: args (Namespace or dict): Configuration object containing training hyperparameters such as learning rate, number of epochs, optimizer type, checkpoint paths, logging options, and device selection. model (torch.nn.Module): The neural network model to train and evaluate. dataloaders (dict): A dictionary containing 'train' and 'val' DataLoader objects for iterating over the dataset splits. Typical usage example: trainer = ICTrainer(args, model, dataloaders) trainer.train() trainer.evaluate() """ def __init__(self, cfg, model, dataloaders): self.cfg = cfg from types import SimpleNamespace from common.utils import flatten_config # convert hydra heirarchical config to flat config args_raw = flatten_config(cfg) # convert dict based config to argument parser type if isinstance(args_raw, dict): args = SimpleNamespace(**args_raw) else: args = args_raw self.args = args self.args.input_size = self.args.input_shape self.model = model self.dataloaders = dataloaders # placeholders for components created later self.device = None self.use_amp = None self.amp_dtype = torch.float16 self.amp_autocast = suppress self.loss_scaler = None self.optimizer = None self.model_ema = None self.loader_train = None self.loader_eval = None self.dataset_train = None self.train_loss_fn = None self.validate_loss_fn = None self.saver = None self.output_dir = None self.writer = None self.lr_scheduler = None self.num_epochs = None self.start_epoch = 0 self.resume_epoch = None self.best_metric = None self.best_epoch = None self.mixup_fn = None self.collate_fn = None #-------------------- Train (main function)----------------- def train(self): self.setup_environment() self.process_model() self.create_optimizer() self.setup_amp_and_scaler() self.resume_if_needed() self.setup_model_ema() self.setup_distributed_and_compile() self.process_dataloaders() self.setup_losses() self.setup_checkpoint_and_logging() self.setup_lr_scheduler_and_start() self.train_loop() onnx_model = torch_model_export_static(cfg=self.cfg, model_dir=self.output_dir, model=self.model.to("cpu")) return onnx_model # ------------------ Environment & Device ------------------ def setup_environment(self): utils.setup_default_logging() args = self.args # copy by reference # cuda / cudnn settings if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True # boolean flags and simple post-processing expected by downstream code args.prefetcher = not getattr(args, 'no_prefetcher', False) args.grad_accum_steps = max(1, getattr(args, 'grad_accum_steps', 1)) self.device = args.device if args.distributed: _logger.info( 'Training in distributed mode with multiple processes, 1 device per process.' f'Process {args.rank}, total {args.world_size}, device {args.device}.') else: _logger.info(f'Training with a single process on 1 device ({args.device}).') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability self.use_amp = None self.amp_dtype = torch.float16 if getattr(args, 'amp', False): if getattr(args, 'amp_impl', 'native') == 'apex': assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' self.use_amp = 'apex' assert args.amp_dtype == 'float16' else: assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' self.use_amp = 'native' assert args.amp_dtype in ('float16', 'bfloat16') if args.amp_dtype == 'bfloat16': self.amp_dtype = torch.bfloat16 utils.random_seed(getattr(args, 'seed', 42), args.rank) if getattr(args, 'fuser', None): utils.set_jit_fuser(args.fuser) if getattr(args, 'fast_norm', False): set_fast_norm() # ------------------ Model creation ------------------ def process_model(self): args = self.args # optional head init if hasattr(self.model, 'get_classifier'): if getattr(args, 'head_init_scale', None) is not None: with torch.no_grad(): self.model.get_classifier().weight.mul_(args.head_init_scale) self.model.get_classifier().bias.mul_(args.head_init_scale) if getattr(args, 'head_init_bias', None) is not None: nn.init.constant_(self.model.get_classifier().bias, args.head_init_bias) if getattr(args, 'num_classes', None) is None: if not hasattr(self.model, 'num_classes'): raise AssertionError('Model must have `num_classes` attr if not set on cmd line/config.') args.num_classes = self.model.num_classes if getattr(args, 'grad_checkpointing', False): # only call if model supports it try: self.model.set_grad_checkpointing(enable=True) except Exception: pass if utils.is_primary(args): _logger.info( f'Model {safe_model_name(args.model_name)} created, param count:{sum([m.numel() for m in self.model.parameters()])}') # resolve data config for model data_config = resolve_data_config(vars(args), model=self.model, verbose=utils.is_primary(args)) self.data_config = data_config # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if getattr(args, 'aug_splits', 0) > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits self.num_aug_splits = num_aug_splits # enable split bn (separate bn stats per batch-portion) if getattr(args, 'split_bn', False): assert num_aug_splits > 1 or getattr(args, 'resplit', False) self.model = convert_splitbn_model(self.model, max(num_aug_splits, 2)) # move model to device, enable channels last layout if set self.model.to(device=self.device) if getattr(args, 'channels_last', False): self.model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and getattr(args, 'sync_bn', False): args.dist_bn = '' # disable dist_bn when sync BN active assert not getattr(args, 'split_bn', False) if has_apex and self.use_amp == 'apex': # Apex SyncBN used with Apex AMP # WARNING this won't currently work with models using BatchNormAct2d self.model = convert_syncbn_model(self.model) else: self.model = convert_sync_batchnorm(self.model) if utils.is_primary(args): _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') # torchscript checks if getattr(args, 'torchscript', False): assert not getattr(args, 'torchcompile', False) assert not self.use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not getattr(args, 'sync_bn', False), 'Cannot use SyncBatchNorm with torchscripted model' self.model = torch.jit.script(self.model) if not args.lr: global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps batch_ratio = global_batch_size / args.lr_base_size if not args.lr_base_scale: on = args.opt.lower() args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear' if args.lr_base_scale == 'sqrt': batch_ratio = batch_ratio ** 0.5 args.lr = args.lr_base * batch_ratio if utils.is_primary(args): _logger.info( f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') # ------------------ Optimizer & AMP ------------------ def create_optimizer(self): args = self.args self.optimizer = create_optimizer_v2( self.model, **optimizer_kwargs(cfg=args), **(getattr(args, 'opt_kwargs', {}) or {}), ) def setup_amp_and_scaler(self): args = self.args # setup automatic mixed-precision (AMP) loss scaling and op casting self.amp_autocast = suppress # default: no-op self.loss_scaler = None if self.use_amp == 'apex': assert self.device.type == 'cuda' self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1') self.loss_scaler = ApexScaler() if utils.is_primary(args): _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif self.use_amp == 'native': try: self.amp_autocast = partial(torch.autocast, device_type=self.device.type, dtype=self.amp_dtype) except (AttributeError, TypeError): # fallback to CUDA only AMP for PyTorch < 1.10 assert self.device.type == 'cuda' self.amp_autocast = torch.cuda.amp.autocast if self.device.type == 'cuda' and self.amp_dtype == torch.float16: # loss scaler only used for float16 (half) dtype, bfloat16 does not need it self.loss_scaler = NativeScaler() if utils.is_primary(args): _logger.info('Using native Torch AMP. Training in mixed precision.') else: if utils.is_primary(args): _logger.info('AMP not enabled. Training in float32.') # ------------------ Resume & EMA & DDP & Compile ------------------ def resume_if_needed(self): args = self.args self.resume_epoch = None if getattr(args, 'resume', None): self.resume_epoch = resume_checkpoint( self.model, args.resume, optimizer=None if getattr(args, 'no_resume_opt', False) else self.optimizer, loss_scaler=None if getattr(args, 'no_resume_opt', False) else self.loss_scaler, log_info=utils.is_primary(args), ) def setup_model_ema(self): args = self.args self.model_ema = None if getattr(args, 'model_ema', False): # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper self.model_ema = utils.ModelEmaV2( self.model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if getattr(args, 'resume', None): # if a resume path provided, try to load EMA weights from it load_checkpoint(self.model_ema.module, args.resume, use_ema=True) def setup_distributed_and_compile(self): args = self.args if args.distributed: if has_apex and self.use_amp == 'apex': if utils.is_primary(args): _logger.info("Using NVIDIA APEX DistributedDataParallel.") self.model = ApexDDP(self.model, delay_allreduce=True) else: if utils.is_primary(args): _logger.info("Using native Torch DistributedDataParallel.") # device variable may be device id / torch.device depending on utils.init_distributed_device self.model = NativeDDP(self.model, device_ids=[self.device], broadcast_buffers=not getattr(args, 'no_ddp_bb', False)) # NOTE: EMA model does not need to be wrapped by DDP if getattr(args, 'torchcompile', False): assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' self.model = torch.compile(self.model, backend=args.torchcompile) # ------------------ Data loaders ------------------ def process_dataloaders(self): args = self.args # setup mixup / cutmix self.collate_fn = None self.mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or getattr(args, 'cutmix_minmax', None) is not None if mixup_active: mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=getattr(args, 'cutmix_minmax', None), prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes ) if args.prefetcher: assert not self.num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) self.collate_fn = FastCollateMixup(**mixup_args) else: self.mixup_fn = Mixup(**mixup_args) self.loader_train, self.loader_eval = self.dataloaders['train'], self.dataloaders['valid'] self.dataset_train = self.loader_train.dataset # ------------------ Loss ------------------ def setup_losses(self): args = self.args num_aug_splits = self.num_aug_splits mixup_active = self.mixup_fn is not None or (self.collate_fn is not None) if getattr(args, 'jsd_loss', False): assert num_aug_splits > 1 # JSD only valid with aug splits set self.train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) elif mixup_active: # smoothing is handled with mixup target transform which outputs sparse, soft targets if getattr(args, 'bce_loss', False): self.train_loss_fn = BinaryCrossEntropy(target_threshold=getattr(args, 'bce_target_thresh', None)) else: self.train_loss_fn = SoftTargetCrossEntropy() elif getattr(args, 'smoothing', 0): if getattr(args, 'bce_loss', False): self.train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=getattr(args, 'bce_target_thresh', None)) else: self.train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: self.train_loss_fn = nn.CrossEntropyLoss() self.train_loss_fn = self.train_loss_fn.to(device=self.device) self.validate_loss_fn = nn.CrossEntropyLoss().to(device=self.device) # ------------------ Checkpoint, logging, saver ------------------ def setup_checkpoint_and_logging(self): args = self.args # setup checkpoint saver and eval metric tracking eval_metric = getattr(args, 'eval_metric', 'accuracy') self.best_metric = None self.best_epoch = None self.saver = None self.output_dir = None if utils.is_primary(args): if getattr(args, 'project_name', None): exp_name = args.project_name else: exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), safe_model_name(args.model_name), str(self.data_config['input_size'][-1]) ]) self.output_dir = os.path.join(args.output_dir, args.saved_models_dir) decreasing = True if eval_metric == 'loss' else False self.saver = CheckpointSaver( model=self.model, optimizer=self.optimizer, args=args, model_ema=self.model_ema, amp_scaler=self.loss_scaler, checkpoint_dir=self.output_dir, recovery_dir=self.output_dir, decreasing=decreasing, max_history=getattr(args, 'checkpoint_hist', 10) ) # write args to file # todo different from before NIKHIl args_text = yaml.safe_dump(vars(args)) with open(os.path.join(self.output_dir, 'args.yaml'), 'w') as f: f.write(args_text) if getattr(args, 'log_tb', False): self.writer = SummaryWriter(log_dir=self.output_dir) if utils.is_primary(args) and getattr(args, 'log_wandb', False): if has_wandb: wandb.init(project=getattr(args, 'experiment', None), config=args) else: _logger.warning( "You've requested to log metrics to wandb but package not found. " "Metrics not being logged to wandb, try `pip install wandb`") # ------------------ LR scheduler and start epoch ------------------ def setup_lr_scheduler_and_start(self): args = self.args updates_per_epoch = (len(self.loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps self.lr_scheduler, self.num_epochs = create_scheduler_v2( self.optimizer, **scheduler_kwargs(args), updates_per_epoch=updates_per_epoch, ) self.start_epoch = 0 if getattr(args, 'start_epoch', None) is not None: self.start_epoch = args.start_epoch elif self.resume_epoch is not None: self.start_epoch = self.resume_epoch if self.lr_scheduler is not None and self.start_epoch > 0: if getattr(args, 'sched_on_updates', False): self.lr_scheduler.step_update(self.start_epoch * updates_per_epoch) else: self.lr_scheduler.step(self.start_epoch) if utils.is_primary(args): _logger.info( f'Scheduled epochs: {self.num_epochs}. LR stepped per {"epoch" if self.lr_scheduler.t_in_epochs else "update"}.') # ------------------ Training loop ------------------ def train_loop(self): args = self.args try: for epoch in range(self.start_epoch, self.num_epochs): # set epoch for samplers if present if hasattr(self.dataset_train, 'set_epoch'): self.dataset_train.set_epoch(epoch) elif args.distributed and hasattr(self.loader_train.sampler, 'set_epoch'): self.loader_train.sampler.set_epoch(epoch) train_metrics = self.train_one_epoch( epoch, self.model, self.loader_train, self.optimizer, self.train_loss_fn, args, lr_scheduler=self.lr_scheduler, saver=self.saver, output_dir=self.output_dir, amp_autocast=self.amp_autocast, loss_scaler=self.loss_scaler, model_ema=self.model_ema, mixup_fn=self.mixup_fn ) # distributed bn sync if requested if args.distributed and getattr(args, 'dist_bn', None) in ('broadcast', 'reduce'): if utils.is_primary(args): _logger.info("Distributing BatchNorm running means and vars") utils.distribute_bn(self.model, args.world_size, args.dist_bn == 'reduce') eval_metrics = self.validate( self.model, self.loader_eval, self.validate_loss_fn, args, amp_autocast=self.amp_autocast, ) eval_metrics_unite = eval_metrics ema_eval_metrics = None if self.model_ema is not None and not getattr(args, 'model_ema_force_cpu', False): if args.distributed and getattr(args, 'dist_bn', None) in ('broadcast', 'reduce'): utils.distribute_bn(self.model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = self.validate( self.model_ema.module, self.loader_eval, self.validate_loss_fn, args, amp_autocast=self.amp_autocast, log_suffix=' (EMA)', ) # choose the best model between normal and EMA using eval_metric eval_metric = getattr(args, 'eval_metric', 'top1') # in some configs metric is 'loss' where smaller is better; original code compared using > if ema_eval_metrics[eval_metric] > eval_metrics[eval_metric]: eval_metrics_unite = ema_eval_metrics if getattr(args, 'dryrun', False): break if self.output_dir is not None: lrs = [param_group['lr'] for param_group in self.optimizer.param_groups] if getattr(args, 'log_tb', False) and self.writer is not None: for key, value in train_metrics.items(): self.writer.add_scalar('train/' + key, value, epoch) for key, value in eval_metrics_unite.items(): self.writer.add_scalar('eval/' + key, value, epoch) for i, lr in enumerate(lrs): self.writer.add_scalar(f'lr/{i}', lr, epoch) utils.update_summary( epoch, train_metrics, eval_metrics_unite, filename=os.path.join(self.output_dir, 'summary.csv'), lr=sum(lrs) / len(lrs), write_header=self.best_metric is None, log_wandb=getattr(args, 'log_wandb', False) and has_wandb, ) if self.saver is not None: # save proper checkpoint with eval metric eval_metric = getattr(args, 'eval_metric', 'top1') save_metric = eval_metrics.get(eval_metric, None) if ema_eval_metrics: save_metric_ema = ema_eval_metrics.get(eval_metric, -1) else: save_metric_ema = -1 self.best_metric, self.best_epoch = self.saver.save_checkpoint(epoch, metric=save_metric, metric_ema=save_metric_ema) if self.lr_scheduler is not None: # step LR for next epoch self.lr_scheduler.step(epoch + 1, eval_metrics_unite.get(getattr(args, 'eval_metric', 'top1'), 0)) except KeyboardInterrupt: pass if self.best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format(self.best_metric, self.best_epoch)) # ------------------ train_one_epoch (member) ------------------ def train_one_epoch( self, epoch, model, loader, optimizer, loss_fn, args=None, device=None, lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, loss_scaler=None, model_ema=None, mixup_fn=None, model_kd=None, ): # This function is intentionally left functionally the same as the original, # but accepts optional args/device and falls back to trainer attributes. if args is None: args = self.args if device is None: device = self.device if self.device is not None else (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')) if getattr(args, 'mixup_off_epoch', None) and epoch >= getattr(args, 'mixup_off_epoch'): if getattr(args, 'prefetcher', False) and hasattr(loader, 'mixup_enabled') and loader.mixup_enabled: loader.mixup_enabled = False elif mixup_fn is not None: mixup_fn.mixup_enabled = False second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order has_no_sync = hasattr(model, "no_sync") update_time_m = utils.AverageMeter() data_time_m = utils.AverageMeter() losses_m = utils.AverageMeter() model.train() accum_steps = args.grad_accum_steps last_accum_steps = len(loader) % accum_steps updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps num_updates = epoch * updates_per_epoch last_batch_idx = len(loader) - 1 last_batch_idx_to_accum = len(loader) - last_accum_steps data_start_time = update_start_time = time.time() optimizer.zero_grad() update_sample_count = 0 for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_batch_idx need_update = last_batch or (batch_idx + 1) % accum_steps == 0 update_idx = batch_idx // accum_steps if batch_idx >= last_batch_idx_to_accum: accum_steps = last_accum_steps if not getattr(args, 'prefetcher', False): input, target = input.to(device), target.to(device) if mixup_fn is not None: input, target = mixup_fn(input, target) if getattr(args, 'channels_last', False): input = input.contiguous(memory_format=torch.channels_last) # multiply by accum steps to get equivalent for full update data_time_m.update(accum_steps * (time.time() - data_start_time)) def _forward(): with amp_autocast(): output = model(input) loss = loss_fn(output, target) if accum_steps > 1: loss /= accum_steps return loss def _backward(_loss): if loss_scaler is not None: loss_scaler( _loss, optimizer, clip_grad=getattr(args, 'clip_grad', None), clip_mode=getattr(args, 'clip_mode', None), parameters=model_parameters(model, exclude_head='agc' in getattr(args, 'clip_mode', '')), # original logic create_graph=second_order, need_update=need_update, ) else: _loss.backward(create_graph=second_order) if need_update: if getattr(args, 'clip_grad', None) is not None: utils.dispatch_clip_grad( model_parameters(model, exclude_head='agc' in getattr(args, 'clip_mode', '')), value=args.clip_grad, mode=getattr(args, 'clip_mode', None), ) optimizer.step() if has_no_sync and not need_update: with model.no_sync(): loss = _forward() _backward(loss) else: loss = _forward() _backward(loss) if not getattr(args, 'distributed', False): losses_m.update(loss.item() * accum_steps, input.size(0)) update_sample_count += input.size(0) if not need_update: data_start_time = time.time() continue num_updates += 1 optimizer.zero_grad() if model_ema is not None: model_ema.update(model) if getattr(args, 'synchronize_step', False) and device.type == 'cuda': torch.cuda.synchronize() time_now = time.time() update_time_m.update(time.time() - update_start_time) update_start_time = time_now if update_idx % getattr(args, 'log_interval', 10) == 0: lrl = [param_group['lr'] for param_group in optimizer.param_groups] lr = sum(lrl) / len(lrl) if getattr(args, 'distributed', False): reduced_loss = utils.reduce_tensor(loss.data, args.world_size) losses_m.update(reduced_loss.item() * accum_steps, input.size(0)) update_sample_count *= args.world_size if utils.is_primary(args): _logger.info( f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} ' f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] ' f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) ' f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s ' f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) ' f'LR: {lr:.3e} ' f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})' ) if getattr(args, 'save_images', False) and output_dir: torchvision.utils.save_image( input, os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), padding=0, normalize=True ) if getattr(args, 'dryrun', False): break if saver is not None and getattr(args, 'recovery_interval', None) and ( (update_idx + 1) % getattr(args, 'recovery_interval') == 0): saver.save_recovery(epoch, batch_idx=update_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) update_sample_count = 0 data_start_time = time.time() # end for if hasattr(optimizer, 'sync_lookahead'): optimizer.sync_lookahead() return OrderedDict([('loss', losses_m.avg)]) # ------------------ validate (member) ------------------ def validate(self, model, loader, loss_fn, args=None, device=None, amp_autocast=suppress, log_suffix=''): if args is None: args = self.args if device is None: device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') batch_time_m = utils.AverageMeter() losses_m = utils.AverageMeter() top1_m = utils.AverageMeter() top5_m = utils.AverageMeter() model.eval() end = time.time() last_idx = len(loader) - 1 with torch.no_grad(): for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx if not getattr(args, 'prefetcher', False): input = input.to(device) target = target.to(device) if getattr(args, 'channels_last', False): input = input.contiguous(memory_format=torch.channels_last) with amp_autocast(): output = model(input) if isinstance(output, (tuple, list)): output = output[0] # augmentation reduction reduce_factor = getattr(args, 'tta', 1) if reduce_factor > 1: output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) target = target[0:target.size(0):reduce_factor] loss = loss_fn(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if getattr(args, 'distributed', False): reduced_loss = utils.reduce_tensor(loss.data, args.world_size) acc1 = utils.reduce_tensor(acc1, args.world_size) acc5 = utils.reduce_tensor(acc5, args.world_size) else: reduced_loss = loss.data if device.type == 'cuda': torch.cuda.synchronize() losses_m.update(reduced_loss.item(), input.size(0)) top1_m.update(acc1.item(), output.size(0)) top5_m.update(acc5.item(), output.size(0)) batch_time_m.update(time.time() - end) end = time.time() if utils.is_primary(args) and (last_batch or batch_idx % getattr(args, 'log_interval', 10) == 0): log_name = 'Test' + log_suffix _logger.info( f'{log_name}: [{batch_idx:>4d}/{last_idx}] ' f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) ' f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) ' f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) ' f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})' ) if getattr(args, 'dryrun', False): break metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) return metrics