# proard/classification/elastic_nn/training/progressive_shrinking.py
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
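#
# Progressive shrinking trains the full network first, then unlocks the elastic
# dimensions (kernel size, depth, expand ratio, width multiplier) stage by
# stage, each stage warm-starting from the previous checkpoint. The
# train_elastic_* helpers below each drive one such stage; this ProArd variant
# additionally tracks robust (adversarial) accuracy next to clean accuracy.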
import torch.nn as nn
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from attacks.utils import ctx_noparamgrad_and_eval
from robust_loss.rslad import rslad_inner_loss, kl_loss
from robust_loss.trades import trades_loss
from attacks import create_attack
import copy
from proard.utils import AverageMeter, cross_entropy_loss_with_soft_target
from proard.utils import (
DistributedMetric,
list_mean,
subset_mean,
val2list,
MyRandomResizedCrop,
)
from proard.classification.run_manager import DistributedRunManager
__all__ = [
"validate",
"train_one_epoch",
"train",
"load_models",
"train_elastic_depth",
"train_elastic_expand",
"train_elastic_width_mult",
]
def validate(
run_manager,
epoch=0,
is_test=False,
image_size_list=None,
ks_list=None,
expand_ratio_list=None,
depth_list=None,
width_mult_list=None,
additional_setting=None,
):
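    """Evaluate the dynamic network over a grid of subnet settings.

    Builds the Cartesian product of depth / expand-ratio / kernel-size /
    width-mult / image-size choices, activates each subnet in turn, resets its
    BN running statistics, and validates it. Returns the subnet-averaged loss,
    clean top-1/top-5, robust top-1/top-5, and a per-subnet log string.
    """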
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
dynamic_net.eval()
if image_size_list is None:
image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
if ks_list is None:
ks_list = dynamic_net.ks_list
if expand_ratio_list is None:
expand_ratio_list = dynamic_net.expand_ratio_list
if depth_list is None:
depth_list = dynamic_net.depth_list
    if width_mult_list is None:
        # default to every width-mult index the dynamic net supports
        # (the original condition was inverted, crashing on the None default)
        if "width_mult_list" in dynamic_net.__dict__:
            width_mult_list = list(range(len(dynamic_net.width_mult_list)))
        else:
            width_mult_list = [0]
subnet_settings = []
for d in depth_list:
for e in expand_ratio_list:
for k in ks_list:
for w in width_mult_list:
for img_size in image_size_list:
subnet_settings.append(
[
{
"image_size": img_size,
"d": d,
"e": e,
"ks": k,
"w": w,
},
"R%s-D%s-E%s-K%s-W%s" % (img_size, d, e, k, w),
]
)
if additional_setting is not None:
subnet_settings += additional_setting
    losses_of_subnets, top1_of_subnets, top5_of_subnets, robust1_of_subnets, robust5_of_subnets = [], [], [], [], []
valid_log = ""
for setting, name in subnet_settings:
run_manager.write_log(
"-" * 30 + " Validate %s " % name + "-" * 30, "train", should_print=False
)
run_manager.run_config.data_provider.assign_active_img_size(
setting.pop("image_size")
)
dynamic_net.set_active_subnet(**setting)
run_manager.write_log(dynamic_net.module_str, "train", should_print=False)
run_manager.reset_running_statistics(dynamic_net)
        loss, (top1, top5, robust1, robust5) = run_manager.validate(
epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net
)
losses_of_subnets.append(loss)
top1_of_subnets.append(top1)
top5_of_subnets.append(top5)
robust1_of_subnets.append(robust1)
robust5_of_subnets.append(robust5)
        valid_log += "%s (%.3f) (%.3f), " % (name, top1, robust1)
return (
list_mean(losses_of_subnets),
list_mean(top1_of_subnets),
list_mean(top5_of_subnets),
list_mean(robust1_of_subnets),
list_mean(robust5_of_subnets),
valid_log,
)
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
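    """Train the dynamic network for one epoch.

    args.dynamic_batch_size subnet configurations are sampled under a
    deterministic seed (see below); the per-batch loss is plain CE, a robust
    criterion, or a distillation loss, depending on run_config.robust_mode and
    args.kd_ratio.
    """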
dynamic_net = run_manager.network
distributed = isinstance(run_manager, DistributedRunManager)
# switch to train mode
dynamic_net.train()
if distributed:
run_manager.run_config.train_loader.sampler.set_epoch(epoch)
MyRandomResizedCrop.EPOCH = epoch
nBatch = len(run_manager.run_config.train_loader)
data_time = AverageMeter()
losses = DistributedMetric("train_loss") if distributed else AverageMeter()
metric_dict = run_manager.get_metric_dict()
with tqdm(
total=nBatch,
desc="Train Epoch #{}".format(epoch + 1),
disable=distributed and not run_manager.is_root,
) as t:
end = time.time()
subnet_str = ""
        j = 0
for _ in range(args.dynamic_batch_size):
# set random seed before sampling
subnet_seed = int("%d%.3d%.3d" % (epoch * nBatch + j, _, 0))
random.seed(subnet_seed)
subnet_settings = dynamic_net.sample_active_subnet()
subnet_str += (
"%d: " % _
+ ",".join(
[
"%s_%s"
% (
key,
"%.1f" % subset_mean(val, 0)
if isinstance(val, list)
else val,
)
for key, val in subnet_settings.items()
]
)
+ " || "
)
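            # seeding with (epoch, batch, sample index) only makes the sampled
            # architecture reproducible and, under distributed training,
            # identical across workers; in the original OFA code each sampled
            # subnet gets its own forward/backward pass on the current batch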
for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
MyRandomResizedCrop.BATCH = i
data_time.update(time.time() - end)
if epoch < warmup_epochs:
new_lr = run_manager.run_config.warmup_adjust_learning_rate(
run_manager.optimizer,
warmup_epochs * nBatch,
nBatch,
epoch,
i,
warmup_lr,
)
else:
new_lr = run_manager.run_config.adjust_learning_rate(
run_manager.optimizer, epoch - warmup_epochs, i, nBatch
)
images, labels = images.cuda(), labels.cuda()
target = labels
# soft target
if args.kd_ratio > 0:
args.teacher_model.eval()
with torch.no_grad():
soft_logits = args.teacher_model(images).detach()
soft_label = F.softmax(soft_logits, dim=1)
# clean gradients
dynamic_net.zero_grad()
loss_of_subnets = []
# compute output
output = dynamic_net(images)
            if args.kd_ratio == 0:
                if run_manager.run_config.robust_mode:
                    # robust training criterion (e.g. TRADES-style) generates
                    # the adversarial examples internally
                    loss = run_manager.train_criterion(
                        dynamic_net, images, labels, run_manager.optimizer,
                        run_manager.run_config.step_size_train,
                        run_manager.run_config.epsilon_train,
                        run_manager.run_config.num_steps_train,
                        run_manager.run_config.beta_train,
                        run_manager.run_config.distance_train,
                    )
                    loss_type = run_manager.run_config.train_criterion_loss.__name__
                else:
                    # nn.CrossEntropyLoss is a module class; call the functional form
                    loss = F.cross_entropy(output, labels)
                    loss_type = "ce"
            else:
                if run_manager.run_config.robust_mode:
                    # robust knowledge distillation from the (frozen) teacher
                    loss = run_manager.kd_criterion(
                        args.teacher_model, dynamic_net, images, labels,
                        run_manager.optimizer,
                        run_manager.run_config.step_size_train,
                        run_manager.run_config.epsilon_train,
                        run_manager.run_config.num_steps_train,
                        run_manager.run_config.beta_train,
                    )
                    loss_type = run_manager.run_config.kd_criterion_loss.__name__
                else:
                    if args.kd_type == "ce":
                        kd_loss = cross_entropy_loss_with_soft_target(
                            output, soft_label
                        )
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    # combine the distillation loss with hard-label
                    # cross-entropy ('loss' was unbound on this branch before)
                    loss = args.kd_ratio * kd_loss + F.cross_entropy(output, labels)
                    loss_type = "%.1fkd+ce" % args.kd_ratio
# measure accuracy and record loss
loss_of_subnets.append(loss)
            # both metric slots (clean / robust) receive the same logits here;
            # the adversarial example only exists inside the robust criterion
            run_manager.update_metric(metric_dict, output, output, target)
loss.backward()
run_manager.optimizer.step()
losses.update(list_mean(loss_of_subnets), images.size(0))
t.set_postfix(
{
"loss": losses.avg.item(),
**run_manager.get_metric_vals(metric_dict, return_dict=True),
"R": images.size(2),
"lr": new_lr,
"loss_type": loss_type,
"seed": str(subnet_seed),
"str": subnet_str,
"data_time": data_time.avg,
}
)
t.update(1)
end = time.time()
            j += 1
return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
def train(run_manager, args, validate_func=None):
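    """Top-level loop: train for n_epochs (+ warmup) epochs, validating every
    args.validation_frequency epochs and tracking the best clean and robust
    top-1 accuracy."""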
distributed = isinstance(run_manager, DistributedRunManager)
if validate_func is None:
validate_func = validate
for epoch in range(
run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs
):
        train_loss, (train_top1, train_top5, train_robust1, train_robust5) = train_one_epoch(
run_manager, args, epoch, args.warmup_epochs, args.warmup_lr
)
if (epoch + 1) % args.validation_frequency == 0:
val_loss, val_acc, val_acc5, val_robust1, val_robust5, _val_log = validate_func(
run_manager, epoch=epoch, is_test=True
)
# best_acc
is_best = val_acc > run_manager.best_acc
is_best_robust = val_robust1 > run_manager.best_robustness
run_manager.best_acc = max(run_manager.best_acc, val_acc)
run_manager.best_robustness = max(run_manager.best_robustness, val_robust1)
if not distributed or run_manager.is_root:
                val_log = (
                    "Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f}), robust-1={5:.3f} ({6:.3f})".format(
                        epoch + 1 - args.warmup_epochs,
                        run_manager.run_config.n_epochs,
                        val_loss,
                        val_acc,
                        run_manager.best_acc,
                        val_robust1,
                        run_manager.best_robustness,
                    )
                )
val_log += ", Train top-1 {top1:.3f}, Train robust-1 {robust1:.3f}, Train loss {loss:.3f}\t".format(
                    top1=train_top1, robust1=train_robust1, loss=train_loss
)
val_log += _val_log
run_manager.write_log(val_log, "valid", should_print=False)
                run_manager.save_model(
                    {
                        "epoch": epoch,
                        "best_acc": run_manager.best_acc,
                        # persist robustness tracking as well, so it survives a resume
                        "best_robustness": run_manager.best_robustness,
                        "optimizer": run_manager.optimizer.state_dict(),
                        "state_dict": run_manager.network.state_dict(),
                    },
                    is_best=is_best,
                )
def load_models(run_manager, dynamic_net, model_path=None):
    # load initialization weights from the given checkpoint
init = torch.load(model_path, map_location="cpu")["state_dict"]
dynamic_net.load_state_dict(init)
run_manager.write_log("Loaded init from %s" % model_path, "valid")
def train_elastic_depth(train_func, run_manager, args, validate_func_dict):
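    """Progressive-shrinking stage for elastic depth.

    Widens the supported depth list from depth_stage_list[:current_stage + 1]
    to depth_stage_list[:current_stage + 2], warm-starting from
    args.dyn_checkpoint_path unless resuming. Validation is restricted to the
    min/max depths when other elastic dimensions are already active.
    """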
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
depth_stage_list = dynamic_net.depth_list.copy()
depth_stage_list.sort(reverse=True)
n_stages = len(depth_stage_list) - 1
current_stage = n_stages - 1
# load pretrained models
if run_manager.start_epoch == 0 and not args.resume:
validate_func_dict["depth_list"] = sorted(dynamic_net.depth_list)
load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
# validate after loading weights
run_manager.write_log(
"%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
% validate(run_manager, is_test=True, **validate_func_dict),
"valid",
)
else:
assert args.resume
run_manager.write_log(
"-" * 30
+ "Supporting Elastic Depth: %s -> %s"
% (depth_stage_list[: current_stage + 1], depth_stage_list[: current_stage + 2])
+ "-" * 30,
"valid",
)
# add depth list constraints
if (
len(set(dynamic_net.ks_list)) == 1
and len(set(dynamic_net.expand_ratio_list)) == 1
):
validate_func_dict["depth_list"] = depth_stage_list
else:
validate_func_dict["depth_list"] = sorted(
{min(depth_stage_list), max(depth_stage_list)}
)
# train
train_func(
run_manager,
args,
lambda _run_manager, epoch, is_test: validate(
_run_manager, epoch, is_test, **validate_func_dict
),
)
def train_elastic_expand(train_func, run_manager, args, validate_func_dict):
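    """Progressive-shrinking stage for elastic expand ratio.

    re_organize_middle_weights sorts channels by importance (channel sorting,
    per the OFA paper) before training, so subnets with smaller expand ratios
    keep the most important channels of the loaded checkpoint.
    """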
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
expand_stage_list = dynamic_net.expand_ratio_list.copy()
expand_stage_list.sort(reverse=True)
n_stages = len(expand_stage_list) - 1
current_stage = n_stages - 1
# load pretrained models
if run_manager.start_epoch == 0 and not args.resume:
validate_func_dict["expand_ratio_list"] = sorted(dynamic_net.expand_ratio_list)
load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage)
run_manager.write_log(
"%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
% validate(run_manager, is_test=True, **validate_func_dict),
"valid",
)
else:
assert args.resume
run_manager.write_log(
"-" * 30
+ "Supporting Elastic Expand Ratio: %s -> %s"
% (
expand_stage_list[: current_stage + 1],
expand_stage_list[: current_stage + 2],
)
+ "-" * 30,
"valid",
)
if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1:
validate_func_dict["expand_ratio_list"] = expand_stage_list
else:
validate_func_dict["expand_ratio_list"] = sorted(
{min(expand_stage_list), max(expand_stage_list)}
)
# train
train_func(
run_manager,
args,
lambda _run_manager, epoch, is_test: validate(
_run_manager, epoch, is_test, **validate_func_dict
),
)
def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict):
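    """Progressive-shrinking stage for elastic width multiplier.

    On the first stage, channels are re-sorted by importance before the
    narrower width settings are trained; validation always covers the smallest
    and largest width-mult indices.
    """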
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
width_stage_list = dynamic_net.width_mult_list.copy()
width_stage_list.sort(reverse=True)
n_stages = len(width_stage_list) - 1
current_stage = n_stages - 1
if run_manager.start_epoch == 0 and not args.resume:
load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
if current_stage == 0:
dynamic_net.re_organize_middle_weights(
expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1
)
run_manager.write_log(
"reorganize_middle_weights (expand_ratio_stage=%d)"
% (len(dynamic_net.expand_ratio_list) - 1),
"valid",
)
            try:
                dynamic_net.re_organize_outer_weights()
                run_manager.write_log("reorganize_outer_weights", "valid")
            except Exception:
                # not every dynamic network implements re_organize_outer_weights
                pass
validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1})
run_manager.write_log(
"%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
% validate(run_manager, is_test=True, **validate_func_dict),
"valid",
)
else:
assert args.resume
run_manager.write_log(
"-" * 30
+ "Supporting Elastic Width Mult: %s -> %s"
% (width_stage_list[: current_stage + 1], width_stage_list[: current_stage + 2])
+ "-" * 30,
"valid",
)
validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1})
# train
train_func(
run_manager,
args,
lambda _run_manager, epoch, is_test: validate(
_run_manager, epoch, is_test, **validate_func_dict
),
)
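

# A minimal usage sketch (hypothetical driver code, not part of this module;
# the `args` fields and validate_func_dict keys mirror the attribute names
# used above):
#
#     validate_func_dict = {
#         "image_size_list": [32],
#         "ks_list": [3],
#         "expand_ratio_list": [1],
#         "depth_list": [2],
#     }
#     # stage 1: unlock smaller depths, warm-started from args.dyn_checkpoint_path
#     train_elastic_depth(train, run_manager, args, validate_func_dict)
#     # stage 2: unlock smaller expand ratios
#     train_elastic_expand(train, run_manager, args, validate_func_dict)
#     # stage 3: unlock additional width multipliers
#     train_elastic_width_mult(train, run_manager, args, validate_func_dict)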