| | |
| | |
| | |
| |
|
| | import torch.nn as nn |
| | import random |
| | import time |
| | import torch |
| | import torch.nn.functional as F |
| | from tqdm import tqdm |
| | from attacks.utils import ctx_noparamgrad_and_eval |
| | from robust_loss.rslad import rslad_inner_loss,kl_loss |
| | from robust_loss.trades import trades_loss |
| | from attacks import create_attack |
| | import copy |
| | from proard.utils import AverageMeter, cross_entropy_loss_with_soft_target |
| | from proard.utils import ( |
| | DistributedMetric, |
| | list_mean, |
| | subset_mean, |
| | val2list, |
| | MyRandomResizedCrop, |
| | ) |
| | from proard.classification.run_manager import DistributedRunManager |
| |
|
| | __all__ = [ |
| | "validate", |
| | "train_one_epoch", |
| | "train", |
| | "load_models", |
| | "train_elastic_depth", |
| | "train_elastic_expand", |
| | "train_elastic_width_mult", |
| | ] |
| |
|
| |
|
| | def validate( |
| | run_manager, |
| | epoch=0, |
| | is_test=False, |
| | image_size_list=None, |
| | ks_list=None, |
| | expand_ratio_list=None, |
| | depth_list=None, |
| | width_mult_list=None, |
| | additional_setting=None, |
| | ): |
| | dynamic_net = run_manager.net |
| | if isinstance(dynamic_net, nn.DataParallel): |
| | dynamic_net = dynamic_net.module |
| |
|
| | dynamic_net.eval() |
| |
|
| | if image_size_list is None: |
| | image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1) |
| | if ks_list is None: |
| | ks_list = dynamic_net.ks_list |
| | if expand_ratio_list is None: |
| | expand_ratio_list = dynamic_net.expand_ratio_list |
| | if depth_list is None: |
| | depth_list = dynamic_net.depth_list |
| | if width_mult_list is not None: |
| | if "width_mult_list" in dynamic_net.__dict__: |
| | width_mult_list = list(range(len(dynamic_net.width_mult_list))) |
| | else: |
| | width_mult_list = [0] |
| |
|
| | subnet_settings = [] |
| | for d in depth_list: |
| | for e in expand_ratio_list: |
| | for k in ks_list: |
| | for w in width_mult_list: |
| | for img_size in image_size_list: |
| | subnet_settings.append( |
| | [ |
| | { |
| | "image_size": img_size, |
| | "d": d, |
| | "e": e, |
| | "ks": k, |
| | "w": w, |
| | }, |
| | "R%s-D%s-E%s-K%s-W%s" % (img_size, d, e, k, w), |
| | ] |
| | ) |
| | if additional_setting is not None: |
| | subnet_settings += additional_setting |
| |
|
| | losses_of_subnets, top1_of_subnets, top5_of_subnets , robust1_of_subnets , robust5_of_subnets = [], [], [],[],[] |
| |
|
| | valid_log = "" |
| | for setting, name in subnet_settings: |
| | run_manager.write_log( |
| | "-" * 30 + " Validate %s " % name + "-" * 30, "train", should_print=False |
| | ) |
| | run_manager.run_config.data_provider.assign_active_img_size( |
| | setting.pop("image_size") |
| | ) |
| | dynamic_net.set_active_subnet(**setting) |
| | run_manager.write_log(dynamic_net.module_str, "train", should_print=False) |
| |
|
| | run_manager.reset_running_statistics(dynamic_net) |
| | loss, (top1, top5,robust1,robust5) = run_manager.validate( |
| | epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net |
| | ) |
| | losses_of_subnets.append(loss) |
| | top1_of_subnets.append(top1) |
| | top5_of_subnets.append(top5) |
| | robust1_of_subnets.append(robust1) |
| | robust5_of_subnets.append(robust5) |
| | valid_log += "%s (%.3f) (%.3f), " % (name, top1,robust1) |
| |
|
| | return ( |
| | list_mean(losses_of_subnets), |
| | list_mean(top1_of_subnets), |
| | list_mean(top5_of_subnets), |
| | list_mean(robust1_of_subnets), |
| | list_mean(robust5_of_subnets), |
| | valid_log, |
| | ) |
| |
|
| |
|
| | def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0): |
| | dynamic_net = run_manager.network |
| | distributed = isinstance(run_manager, DistributedRunManager) |
| |
|
| | |
| | dynamic_net.train() |
| | if distributed: |
| | run_manager.run_config.train_loader.sampler.set_epoch(epoch) |
| | MyRandomResizedCrop.EPOCH = epoch |
| |
|
| | nBatch = len(run_manager.run_config.train_loader) |
| |
|
| | data_time = AverageMeter() |
| | losses = DistributedMetric("train_loss") if distributed else AverageMeter() |
| | metric_dict = run_manager.get_metric_dict() |
| |
|
| | with tqdm( |
| | total=nBatch, |
| | desc="Train Epoch #{}".format(epoch + 1), |
| | disable=distributed and not run_manager.is_root, |
| | ) as t: |
| | end = time.time() |
| | subnet_str = "" |
| | j=0 |
| | for _ in range(args.dynamic_batch_size): |
| | |
| | subnet_seed = int("%d%.3d%.3d" % (epoch * nBatch + j, _, 0)) |
| | random.seed(subnet_seed) |
| | subnet_settings = dynamic_net.sample_active_subnet() |
| | subnet_str += ( |
| | "%d: " % _ |
| | + ",".join( |
| | [ |
| | "%s_%s" |
| | % ( |
| | key, |
| | "%.1f" % subset_mean(val, 0) |
| | if isinstance(val, list) |
| | else val, |
| | ) |
| | for key, val in subnet_settings.items() |
| | ] |
| | ) |
| | + " || " |
| | ) |
| |
|
| | for i, (images, labels) in enumerate(run_manager.run_config.train_loader): |
| | MyRandomResizedCrop.BATCH = i |
| | data_time.update(time.time() - end) |
| | if epoch < warmup_epochs: |
| | new_lr = run_manager.run_config.warmup_adjust_learning_rate( |
| | run_manager.optimizer, |
| | warmup_epochs * nBatch, |
| | nBatch, |
| | epoch, |
| | i, |
| | warmup_lr, |
| | ) |
| | else: |
| | new_lr = run_manager.run_config.adjust_learning_rate( |
| | run_manager.optimizer, epoch - warmup_epochs, i, nBatch |
| | ) |
| |
|
| | images, labels = images.cuda(), labels.cuda() |
| | target = labels |
| |
|
| | |
| | if args.kd_ratio > 0: |
| | args.teacher_model.eval() |
| | with torch.no_grad(): |
| | soft_logits = args.teacher_model(images).detach() |
| | soft_label = F.softmax(soft_logits, dim=1) |
| |
|
| | |
| | dynamic_net.zero_grad() |
| |
|
| | loss_of_subnets = [] |
| | |
| | |
| | |
| | output = dynamic_net(images) |
| | |
| | if args.kd_ratio == 0: |
| | if run_manager.run_config.robust_mode: |
| | loss = run_manager.train_criterion(dynamic_net,images,labels,run_manager.optimizer,run_manager.run_config.step_size_train,run_manager.run_config.epsilon_train,run_manager.run_config.num_steps_train,run_manager.run_config.beta_train,run_manager.run_config.distance_train) |
| | loss_type = run_manager.run_config.train_criterion_loss.__name__ |
| | else: |
| | loss = torch.nn.CrossEntropyLoss(output,labels) |
| | loss_type = 'ce' |
| | else: |
| | if run_manager.run_config.robust_mode: |
| | loss = run_manager.kd_criterion(args.teacher_model,dynamic_net,images,labels,run_manager.optimizer,run_manager.run_config.step_size_train,run_manager.run_config.epsilon_train,run_manager.run_config.num_steps_train,run_manager.run_config.beta_train) |
| | loss_type = run_manager.run_config.kd_criterion_loss.__name__ |
| | else: |
| | if args.kd_type == "ce": |
| | kd_loss = cross_entropy_loss_with_soft_target( |
| | output, soft_label |
| | ) |
| | else: |
| | kd_loss = F.mse_loss(output, soft_logits) |
| | loss = args.kd_ratio * kd_loss + loss |
| | loss_type = "%.1fkd+ce" % args.kd_ratio |
| | |
| | loss_of_subnets.append(loss) |
| | run_manager.update_metric(metric_dict, output,output, target) |
| |
|
| | loss.backward() |
| | run_manager.optimizer.step() |
| |
|
| | losses.update(list_mean(loss_of_subnets), images.size(0)) |
| |
|
| | t.set_postfix( |
| | { |
| | "loss": losses.avg.item(), |
| | **run_manager.get_metric_vals(metric_dict, return_dict=True), |
| | "R": images.size(2), |
| | "lr": new_lr, |
| | "loss_type": loss_type, |
| | "seed": str(subnet_seed), |
| | "str": subnet_str, |
| | "data_time": data_time.avg, |
| | } |
| | ) |
| | t.update(1) |
| | end = time.time() |
| | j+=1 |
| | return losses.avg.item(), run_manager.get_metric_vals(metric_dict) |
| |
|
| |
|
| | def train(run_manager, args, validate_func=None): |
| | distributed = isinstance(run_manager, DistributedRunManager) |
| | if validate_func is None: |
| | validate_func = validate |
| |
|
| | for epoch in range( |
| | run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs |
| | ): |
| | train_loss, (train_top1, train_top5 , train_robust1 , train_robust5) = train_one_epoch( |
| | run_manager, args, epoch, args.warmup_epochs, args.warmup_lr |
| | ) |
| |
|
| | if (epoch + 1) % args.validation_frequency == 0: |
| | val_loss, val_acc, val_acc5, val_robust1, val_robust5, _val_log = validate_func( |
| | run_manager, epoch=epoch, is_test=True |
| | ) |
| | |
| | is_best = val_acc > run_manager.best_acc |
| | is_best_robust = val_robust1 > run_manager.best_robustness |
| | run_manager.best_acc = max(run_manager.best_acc, val_acc) |
| | run_manager.best_robustness = max(run_manager.best_robustness, val_robust1) |
| | if not distributed or run_manager.is_root: |
| | val_log = ( |
| | "Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f}) , robust-1 = {4:.3f} ({5:.3f}) ".format( |
| | epoch + 1 - args.warmup_epochs, |
| | run_manager.run_config.n_epochs, |
| | val_loss, |
| | val_acc, |
| | run_manager.best_acc, |
| | val_robust1, |
| | run_manager.best_robustness, |
| | ) |
| | ) |
| | val_log += ", Train top-1 {top1:.3f}, Train robust-1 {robust1:.3f}, Train loss {loss:.3f}\t".format( |
| | top1=train_top1, robust1 = train_robust1, loss=train_loss |
| | ) |
| | val_log += _val_log |
| | run_manager.write_log(val_log, "valid", should_print=False) |
| |
|
| | run_manager.save_model( |
| | { |
| | "epoch": epoch, |
| | "best_acc": run_manager.best_acc, |
| | "optimizer": run_manager.optimizer.state_dict(), |
| | "state_dict": run_manager.network.state_dict(), |
| | }, |
| | is_best=is_best, |
| | ) |
| |
|
| |
|
| | def load_models(run_manager, dynamic_net, model_path=None): |
| | |
| | init = torch.load(model_path, map_location="cpu")["state_dict"] |
| | dynamic_net.load_state_dict(init) |
| | run_manager.write_log("Loaded init from %s" % model_path, "valid") |
| |
|
| |
|
| | def train_elastic_depth(train_func, run_manager, args, validate_func_dict): |
| | dynamic_net = run_manager.net |
| | if isinstance(dynamic_net, nn.DataParallel): |
| | dynamic_net = dynamic_net.module |
| |
|
| | depth_stage_list = dynamic_net.depth_list.copy() |
| | depth_stage_list.sort(reverse=True) |
| | n_stages = len(depth_stage_list) - 1 |
| | current_stage = n_stages - 1 |
| |
|
| | |
| | if run_manager.start_epoch == 0 and not args.resume: |
| | validate_func_dict["depth_list"] = sorted(dynamic_net.depth_list) |
| |
|
| | load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path) |
| | |
| | run_manager.write_log( |
| | "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" |
| | % validate(run_manager, is_test=True, **validate_func_dict), |
| | "valid", |
| | ) |
| | else: |
| | assert args.resume |
| |
|
| | run_manager.write_log( |
| | "-" * 30 |
| | + "Supporting Elastic Depth: %s -> %s" |
| | % (depth_stage_list[: current_stage + 1], depth_stage_list[: current_stage + 2]) |
| | + "-" * 30, |
| | "valid", |
| | ) |
| | |
| | if ( |
| | len(set(dynamic_net.ks_list)) == 1 |
| | and len(set(dynamic_net.expand_ratio_list)) == 1 |
| | ): |
| | validate_func_dict["depth_list"] = depth_stage_list |
| | else: |
| | validate_func_dict["depth_list"] = sorted( |
| | {min(depth_stage_list), max(depth_stage_list)} |
| | ) |
| |
|
| | |
| | train_func( |
| | run_manager, |
| | args, |
| | lambda _run_manager, epoch, is_test: validate( |
| | _run_manager, epoch, is_test, **validate_func_dict |
| | ), |
| | ) |
| |
|
| |
|
| | def train_elastic_expand(train_func, run_manager, args, validate_func_dict): |
| | dynamic_net = run_manager.net |
| | if isinstance(dynamic_net, nn.DataParallel): |
| | dynamic_net = dynamic_net.module |
| |
|
| | expand_stage_list = dynamic_net.expand_ratio_list.copy() |
| | expand_stage_list.sort(reverse=True) |
| | n_stages = len(expand_stage_list) - 1 |
| | current_stage = n_stages - 1 |
| |
|
| | |
| | if run_manager.start_epoch == 0 and not args.resume: |
| | validate_func_dict["expand_ratio_list"] = sorted(dynamic_net.expand_ratio_list) |
| |
|
| | load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path) |
| | dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage) |
| | run_manager.write_log( |
| | "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" |
| | % validate(run_manager, is_test=True, **validate_func_dict), |
| | "valid", |
| | ) |
| | else: |
| | assert args.resume |
| |
|
| | run_manager.write_log( |
| | "-" * 30 |
| | + "Supporting Elastic Expand Ratio: %s -> %s" |
| | % ( |
| | expand_stage_list[: current_stage + 1], |
| | expand_stage_list[: current_stage + 2], |
| | ) |
| | + "-" * 30, |
| | "valid", |
| | ) |
| | if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1: |
| | validate_func_dict["expand_ratio_list"] = expand_stage_list |
| | else: |
| | validate_func_dict["expand_ratio_list"] = sorted( |
| | {min(expand_stage_list), max(expand_stage_list)} |
| | ) |
| |
|
| | |
| | train_func( |
| | run_manager, |
| | args, |
| | lambda _run_manager, epoch, is_test: validate( |
| | _run_manager, epoch, is_test, **validate_func_dict |
| | ), |
| | ) |
| |
|
| |
|
| | def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict): |
| | dynamic_net = run_manager.net |
| | if isinstance(dynamic_net, nn.DataParallel): |
| | dynamic_net = dynamic_net.module |
| |
|
| | width_stage_list = dynamic_net.width_mult_list.copy() |
| | width_stage_list.sort(reverse=True) |
| | n_stages = len(width_stage_list) - 1 |
| | current_stage = n_stages - 1 |
| |
|
| | if run_manager.start_epoch == 0 and not args.resume: |
| | load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path) |
| | if current_stage == 0: |
| | dynamic_net.re_organize_middle_weights( |
| | expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1 |
| | ) |
| | run_manager.write_log( |
| | "reorganize_middle_weights (expand_ratio_stage=%d)" |
| | % (len(dynamic_net.expand_ratio_list) - 1), |
| | "valid", |
| | ) |
| | try: |
| | dynamic_net.re_organize_outer_weights() |
| | run_manager.write_log("reorganize_outer_weights", "valid") |
| | except Exception: |
| | pass |
| | validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1}) |
| | run_manager.write_log( |
| | "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" |
| | % validate(run_manager, is_test=True, **validate_func_dict), |
| | "valid", |
| | ) |
| | else: |
| | assert args.resume |
| |
|
| | run_manager.write_log( |
| | "-" * 30 |
| | + "Supporting Elastic Width Mult: %s -> %s" |
| | % (width_stage_list[: current_stage + 1], width_stage_list[: current_stage + 2]) |
| | + "-" * 30, |
| | "valid", |
| | ) |
| | validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1}) |
| |
|
| | |
| | train_func( |
| | run_manager, |
| | args, |
| | lambda _run_manager, epoch, is_test: validate( |
| | _run_manager, epoch, is_test, **validate_func_dict |
| | ), |
| | ) |
| |
|