|
|
import sys |
|
|
import datetime |
|
|
import random |
|
|
import numpy as np |
|
|
import time |
|
|
import torch |
|
|
import torch.backends.cudnn as cudnn |
|
|
import json |
|
|
|
|
|
from pathlib import Path |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
from timm.data import Mixup |
|
|
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy |
|
|
from timm.scheduler import create_scheduler |
|
|
from timm.optim import create_optimizer |
|
|
from timm.utils import NativeScaler, get_state_dict, ModelEma |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from methods.engine_StyleAdv_ViT import train_one_epoch_styleAdv, evaluate |
|
|
|
|
|
import utils.deit_util as utils |
|
|
|
|
|
|
|
|
from data.pmf_datasets import get_loaders_withGlobalID |
|
|
|
|
|
from utils.args import get_args_parser |
|
|
|
|
|
|
|
|
from methods.load_ViT_models import get_model |
|
|
|
|
|
|
|
|
|
|
|
# Fixed learning rate for the classifier-head parameter group of the SGD
# optimizer built in main(); the backbone uses args.lr instead.
lr_classifier = 0.001
|
|
|
|
|
|
|
|
def main(args):
    """Train/evaluate a StyleAdv ProtoNet-ViT model.

    Sets up distributed training, per-rank seeding, episodic data loaders,
    optional Mixup/CutMix, the ViT-small ProtoNet model (optionally DDP and
    EMA wrapped), an SGD optimizer with a separate fixed learning rate for
    the classifier head, then runs the train/eval loop, checkpointing and
    appending JSON stats to ``args.output_dir / log.txt``.

    Args:
        args: parsed argparse namespace from ``utils.args.get_args_parser()``.
    """
    utils.init_distributed_mode(args)

    print(args)
    device = torch.device(args.device)

    # Per-rank seed so each distributed worker draws distinct episodes
    # while the run as a whole stays reproducible.
    seed = args.seed + utils.get_rank()
    args.seed = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    output_dir = Path(args.output_dir)
    if utils.is_main_process():
        output_dir.mkdir(parents=True, exist_ok=True)
        # Record the exact command line used for this run.
        with (output_dir / "log.txt").open("a") as f:
            f.write(" ".join(sys.argv) + "\n")

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    data_loader_train, data_loader_val = get_loaders_withGlobalID(args, num_tasks, global_rank)

    # Mixup / CutMix augmentation over the nClsEpisode classes of each episode.
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nClsEpisode)

    print(f"Creating model: ProtoNet {args.arch}")
    model = get_model(backbone = 'vit_small', classifier='protonet', styleAdv=True)

    model.to(device)

    # Optional exponential moving average of the model weights.
    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=args.unused_params)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Scale the base learning rate under fp16 training (empirical 1/8 factor,
    # linearly scaled by world size).
    if args.fp16:
        scale = 1 / 8
        linear_scaled_lr = args.lr * utils.get_world_size() * scale
        args.lr = linear_scaled_lr

    loss_scaler = NativeScaler() if args.fp16 else None

    # SGD with two parameter groups: the backbone ('feature') at args.lr and
    # the classifier head at the fixed module-level lr_classifier.
    # BUGFIX: the original used a dict comprehension {'params': p for p in ...},
    # which collapses to {'params': <last parameter only>} — leaving all but one
    # backbone parameter out of the optimizer. A list of parameters is intended.
    optimizer = torch.optim.SGD(
        [{'params': [p for p in model_without_ddp.feature.parameters() if p.requires_grad]},
         {'params': model_without_ddp.classifier.parameters(), 'lr': lr_classifier}],
        args.lr,
        momentum=args.momentum,
        weight_decay=0,  # no weight decay for fine-tuning
    )
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Training criterion: soft targets under mixup, else label smoothing,
    # else plain cross-entropy.
    if args.mixup > 0.:
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        model_without_ddp.load_state_dict(checkpoint['model'])

        # Restore full training state only when continuing training.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            # Guard on loss_scaler: a fp16 checkpoint may be resumed without --fp16.
            if 'scaler' in checkpoint and loss_scaler is not None:
                loss_scaler.load_state_dict(checkpoint['scaler'])

        print(f'Resume from {args.resume} at epoch {args.start_epoch}.')

    # Initial evaluation before any training (logged with epoch = -1).
    test_stats = evaluate(data_loader_val, model, criterion, device, args.seed+10000)
    print(f"Accuracy of the network on dataset_val: {test_stats['acc1']:.1f}%")
    if args.output_dir and utils.is_main_process():
        test_stats['epoch'] = -1
        with (output_dir / "log.txt").open("a") as f:
            f.write(json.dumps(test_stats) + "\n")

    if args.eval:
        return

    # TensorBoard writer on the main process only; workers pass writer=None.
    if utils.is_main_process():
        writer = SummaryWriter(log_dir=str(output_dir))
    else:
        writer = None

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()

    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        print('args.start_epoch:', args.start_epoch, 'args.epochs:', args.epochs, 'tmp epoch:', epoch)
        train_stats = train_one_epoch_styleAdv(
            data_loader_train, model, criterion, optimizer, epoch, device,
            loss_scaler, args.fp16, args.clip_grad, model_ema, mixup_fn, writer,
            set_training_mode=False
        )

        lr_scheduler.step(epoch)

        test_stats = evaluate(data_loader_val, model, criterion, device, args.seed+10000)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth', output_dir / 'best.pth']
            for checkpoint_path in checkpoint_paths:
                state_dict = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'model_ema': get_state_dict(model_ema) if args.model_ema else None,
                    'args': args,
                }
                if loss_scaler is not None:
                    # BUGFIX: save under 'scaler' (the key the resume path reads);
                    # the original wrote 'scalar', so AMP scaler state was never
                    # restored on resume.
                    state_dict['scaler'] = loss_scaler.state_dict()
                utils.save_on_master(state_dict, checkpoint_path)

        # NOTE(review): training stops the FIRST time validation accuracy fails
        # to improve on the running best; this first-plateau early stop is
        # preserved as-is but looks aggressive — confirm it is intended.
        if test_stats["acc1"] <= max_accuracy:
            break

        print(f"Accuracy of the network on dataset_val: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if args.output_dir and utils.is_main_process():
            log_stats['best_test_acc'] = max_accuracy
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    if utils.is_main_process():
        writer.close()
    # Close any HDF5 files left open by the PyTables-backed data pipeline.
    import tables
    tables.file._open_files.close_all()
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse command-line arguments and launch training/evaluation.
    main(get_args_parser().parse_args())
|
|
|