"""Entry point for training/evaluating a ProtoNet ViT with the StyleAdv engine.

The script builds the episodic data loaders, the model, the optimizer, the
learning-rate scheduler and the loss, optionally resumes from a checkpoint,
runs one evaluation pass, and then (unless ``args.eval`` is set) trains for
``args.epochs`` epochs while writing ``checkpoint.pth`` / ``best.pth`` and a
JSON-lines ``log.txt`` into ``args.output_dir``.
"""

import sys
import datetime
import random
import time
import json
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma

from methods.engine_StyleAdv_ViT import train_one_epoch_styleAdv, evaluate
import utils.deit_util as utils
from data.pmf_datasets import get_loaders_withGlobalID
from utils.args import get_args_parser
from methods.load_ViT_models import get_model

# Learning rate of the classifier parameter group (the feature group uses
# args.lr). Other values noted by the original author: 5e-5, 0.01, 0.0001.
lr_classifier = 0.001


def main(args):
    """Run the evaluation pass and, unless ``args.eval`` is set, training.

    Args:
        args: Parsed command-line namespace produced by ``get_args_parser()``.
            It is mutated in place: ``args.seed`` is offset by the process
            rank, ``args.lr`` is rescaled when ``args.fp16`` is set, and
            ``args.start_epoch`` is advanced when training state is resumed.

    Returns:
        None. Results are printed and, on the main process, appended to
        ``<output_dir>/log.txt``; checkpoints are written via
        ``utils.save_on_master``.
    """
    utils.init_distributed_mode(args)

    print(args)
    device = torch.device(args.device)

    # Fix the seed for reproducibility; each rank gets a distinct seed.
    seed = args.seed + utils.get_rank()
    args.seed = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    output_dir = Path(args.output_dir)
    if utils.is_main_process():
        output_dir.mkdir(parents=True, exist_ok=True)
        # Record the exact command line so a run can be reproduced later.
        with (output_dir / "log.txt").open("a") as f:
            f.write(" ".join(sys.argv) + "\n")

    ##############################################
    # Data loaders
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    data_loader_train, data_loader_val = get_loaders_withGlobalID(
        args, num_tasks, global_rank)

    ##############################################
    # Mixup regularization (by default OFF)
    mixup_fn = None
    mixup_active = (
        args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    )
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.nClsEpisode)

    ##############################################
    # Model
    # NOTE(review): the message prints args.arch, but the backbone below is
    # hard-coded to 'vit_small' -- confirm this is intended.
    print(f"Creating model: ProtoNet {args.arch}")

    model = get_model(backbone='vit_small', classifier='protonet', styleAdv=True)
    model.to(device)

    model_ema = None  # (by default OFF)
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=args.unused_params)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    ##############################################
    # Optimizer & scheduler & criterion
    if args.fp16:
        scale = 1 / 8  # the default lr is for 8 GPUs
        linear_scaled_lr = args.lr * utils.get_world_size() * scale
        args.lr = linear_scaled_lr

    loss_scaler = NativeScaler() if args.fp16 else None

    # Two parameter groups: the feature extractor at args.lr and the
    # classifier at lr_classifier. The 'params' value must be a *list*; the
    # previous `{'params': p for p in ...}` was a dict comprehension that kept
    # only the last trainable feature parameter, leaving every other backbone
    # weight out of the optimizer.
    feature_params = [
        p for p in model_without_ddp.feature.parameters() if p.requires_grad
    ]
    optimizer = torch.optim.SGD(
        [
            {'params': feature_params},
            {'params': model_without_ddp.classifier.parameters(),
             'lr': lr_classifier},
        ],
        args.lr,
        momentum=args.momentum,
        weight_decay=0,  # no weight decay for fine-tuning
    )

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Mixup/CutMix produce soft targets, so pick the soft-target loss whenever
    # the Mixup transform is active (not only when args.mixup > 0).
    if mixup_active:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    ##############################################
    # Resume training from ckpt (model, optimizer, lr_scheduler, epoch,
    # model_ema, scaler)
    if args.resume:
        # NOTE(security): checkpoints are unpickled on load; only resume from
        # trusted files/URLs.
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        model_without_ddp.load_state_dict(checkpoint['model'])

        has_train_state = (
            'optimizer' in checkpoint
            and 'lr_scheduler' in checkpoint
            and 'epoch' in checkpoint
        )
        if not args.eval and has_train_state:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            # Older checkpoints written by this script stored the scaler under
            # the misspelled key 'scalar'; accept both. Skip when AMP is off
            # in this run (loss_scaler is None).
            scaler_state = checkpoint.get('scaler', checkpoint.get('scalar'))
            if loss_scaler is not None and scaler_state is not None:
                loss_scaler.load_state_dict(scaler_state)

        print(f'Resume from {args.resume} at epoch {args.start_epoch}.')

    ##############################################
    # Test (always run once before training; the only step when args.eval)
    test_stats = evaluate(data_loader_val, model, criterion, device,
                          args.seed + 10000)
    print(f"Accuracy of the network on dataset_val: {test_stats['acc1']:.1f}%")
    if args.output_dir and utils.is_main_process():
        test_stats['epoch'] = -1  # marks the pre-training evaluation in the log
        with (output_dir / "log.txt").open("a") as f:
            f.write(json.dumps(test_stats) + "\n")

    if args.eval:
        return

    ##############################################
    # Training
    if utils.is_main_process():
        writer = SummaryWriter(log_dir=str(output_dir))
    else:
        writer = None

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    # Starts from 0 rather than the pre-training accuracy, so best.pth is
    # written by the first epoch with non-zero accuracy.
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        print('args.start_epoch:', args.start_epoch, 'args.epochs:', args.epochs, 'tmp epoch:', epoch)
        train_stats = train_one_epoch_styleAdv(
            data_loader_train, model, criterion, optimizer, epoch, device,
            loss_scaler, args.fp16, args.clip_grad, model_ema, mixup_fn, writer,
            set_training_mode=False  # TODO: may need eval mode for finetuning
        )

        lr_scheduler.step(epoch)

        test_stats = evaluate(data_loader_val, model, criterion, device,
                              args.seed + 10000)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir:
            state_dict = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'model_ema': get_state_dict(model_ema) if args.model_ema else None,
                'args': args,
            }
            if loss_scaler is not None:
                # Key must match the one read back in the resume branch.
                state_dict['scaler'] = loss_scaler.state_dict()

            # The latest checkpoint is always written; best.pth only when the
            # validation accuracy strictly improves.
            utils.save_on_master(state_dict, output_dir / 'checkpoint.pth')
            if test_stats["acc1"] > max_accuracy:
                utils.save_on_master(state_dict, output_dir / 'best.pth')

        print(f"Accuracy of the network on dataset_val: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if args.output_dir and utils.is_main_process():
            log_stats['best_test_acc'] = max_accuracy
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    if utils.is_main_process():
        writer.close()

    # Close any HDF5 files PyTables still holds open (presumably opened by
    # the dataset loaders -- TODO confirm).
    import tables
    tables.file._open_files.close_all()


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)