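"""Linear probing of a frozen self-supervised Vision Transformer.

Trains a single linear classifier on features extracted from the last
`--n_last_blocks` blocks of a pretrained ViT backbone (the backbone stays
frozen), following the DINO linear-evaluation protocol, with optional
Mixup/CutMix augmentation and a timm learning-rate schedule.

Example invocation (script name and paths are placeholders, adjust to your
setup):

    python eval_linear.py --arch vit_small --patch_size 16 \
        --pretrained_weights /path/to/checkpoint.pth \
        --dataset imagenet --data_path /path/to/imagenet/ \
        --output_dir ./checkpoints
"""
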
import os
import sys
import argparse
import json
from pathlib import Path
import time
import datetime

import torch
from torch import nn
import torch.backends.cudnn as cudnn

import numpy as np
import shutil
from torch.utils.tensorboard import SummaryWriter

from timm.scheduler import create_scheduler
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data import Mixup

import utils
import vision_transformer as vits
from samplers import RASampler
from datasets import build_dataset


def main(args):
    if args.device != 'cuda':
        args.distributed = False
    else:
        utils.init_distributed_mode(args)

    print(args)

    # Fix random seeds, offset by process rank so each worker gets a distinct stream.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device(args.device)
    cudnn.benchmark = True

    # Build the frozen backbone.
    if args.arch in vits.__dict__.keys():
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=args.num_labels,
                                         adjacency_bp=args.adjacency_bp, temperature=args.temperature)
        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
    else:
        print(f"Unknown architecture: {args.arch}")
        sys.exit(1)

    model.to(device)
    model.eval()
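
    # The probe's input is the concatenation of the [CLS] tokens from the last
    # `n_last_blocks` blocks, optionally with the average-pooled patch tokens
    # of the final block appended; `embed_dim` above is scaled accordingly.
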
    model_without_ddp = model

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    utils.load_pretrained_weights(model_without_ddp, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    print(f"Model {args.arch} built.")

    # Build the datasets before the classifier: build_dataset may update
    # args.num_labels (e.g. CUB has 200 classes), and the probe head must match.
    dataset_train, args.num_labels = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
    linear_classifier = linear_classifier.to(device)
    classifier_without_ddp = linear_classifier
    if args.distributed:
        linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()

    if args.distributed:
        if args.data_aug and args.repeated_aug:
            # RASampler repeats each image several times per epoch with
            # different augmentations (repeated augmentation, as in DeiT).
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.distributed.DistributedSampler(dataset_train)
        sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
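
    # Note: with a DistributedSampler, set_epoch() must be called every epoch
    # (done in the training loop below) so shuffling differs across epochs.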

    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    if args.evaluate:
        # Evaluation-only mode: restore the backbone and classifier, then exit.
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        classifier_without_ddp.load_state_dict(checkpoint['classifier'])
        test_stats = validate_network(val_loader, model_without_ddp, device, linear_classifier,
                                      args.n_last_blocks, args.avgpool_patchtokens)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    optimizer = torch.optim.SGD(
        linear_classifier.parameters(),
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear LR scaling
        momentum=0.9,
        weight_decay=args.weight_decay,
    )
    scheduler, _ = create_scheduler(args, optimizer)
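    # Example of the linear scaling rule above (illustrative numbers): with
    # 8 GPUs and 128 images per GPU, --lr 1e-4 gives an effective starting LR
    # of 1e-4 * (128 * 8) / 256 = 4e-4.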

    criterion = nn.CrossEntropyLoss()

    mixup_fn = None
    if args.data_aug:
        print('Data augmentation: Mixup / CutMix enabled')
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_labels)
        # Mixup/CutMix produce soft targets, so swap in a soft-target loss.
        criterion = SoftTargetCrossEntropy()
    elif args.label_smooth_loss:
        # Optional label smoothing on hard targets (only when mixup is disabled).
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    writer = None
    if utils.is_main_process():
        writer = SummaryWriter(args.output_dir + '/log')
    start_epoch = 0
    best_acc = 0.
    print("Starting training")
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)

        train_stats = train(model_without_ddp, device, optimizer, train_loader, epoch, criterion,
                            linear_classifier, args.n_last_blocks, args.avgpool_patchtokens, mixup_fn)

        scheduler.step(epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
            test_stats = validate_network(val_loader, model, device, linear_classifier,
                                          args.n_last_blocks, args.avgpool_patchtokens)
            print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            log_stats = {**log_stats,
                         **{f'test_{k}': v for k, v in test_stats.items()}}
            if utils.is_main_process():
                with (Path(args.output_dir) / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")
                save_dict = {
                    "epoch": epoch + 1,
                    "classifier": classifier_without_ddp.state_dict(),
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "best_acc": best_acc,
                }

                if writer is not None:
                    writer.add_scalar('Train_loss', train_stats['loss'], global_step=epoch)
                    writer.add_scalar('Learning_rate', train_stats['lr'], global_step=epoch)
                    writer.add_scalar('Train Acc_1', train_stats['acc1'], global_step=epoch)
                    writer.add_scalar('Acc_1', test_stats['acc1'], global_step=epoch)
                    if 'acc5' in test_stats:
                        writer.add_scalar('Acc_5', test_stats['acc5'], global_step=epoch)

                checkpoint_path = os.path.join(args.output_dir, "checkpoint.pth")
                torch.save(save_dict, checkpoint_path)

                if best_acc < float(test_stats['acc1']):
                    best_acc = float(test_stats['acc1'])
                    shutil.copyfile(checkpoint_path, os.path.join(args.output_dir, 'model_best.pth'))
                print(f'Max accuracy so far: {best_acc:.2f}%')

| print("Training of the TokenCut completed.\n" |
| "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)) |
| total_time = time.time() - start_time |
| total_time_str = str(datetime.timedelta(seconds=int(total_time))) |
| print(f'Training time {total_time_str}') |
| |
|
|
def train(model, device, optimizer, loader, epoch, criterion, linear_classifier, n, avgpool, mixup_fn=None):
    linear_classifier.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    for batch in metric_logger.log_every(loader, 20, header):
        inp, target = batch[:2]

        inp = inp.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # Keep the hard labels: accuracy is measured against them even when
        # mixup replaces `target` with soft labels.
        hard_target = target.clone()
        if mixup_fn is not None:
            inp, target = mixup_fn(inp, target)

        # Forward through the frozen backbone without gradients; only the
        # linear classifier below receives gradients.
        with torch.no_grad():
            intermediate_output, _ = model.get_intermediate_layers(inp, n)
            output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            if avgpool:
                output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                output = output.reshape(output.shape[0], -1)
        output = linear_classifier(output)

        loss = criterion(output, target)

        acc1, = utils.accuracy(output, hard_target, topk=(1,))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if device.type == 'cuda':
            torch.cuda.synchronize()
        batch_size = inp.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)

    # Gather the stats from all processes before reporting epoch averages.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
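

# The dict returned by train() maps meter names to their epoch-level global
# averages, e.g. (illustrative values only):
#
#     {'loss': 1.23, 'lr': 0.0005, 'acc1': 67.8}

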
@torch.no_grad()
def validate_network(val_loader, model, device, linear_classifier, n, avgpool):
    linear_classifier.eval()
    # Reach through the DDP wrapper (when present) for classifier attributes.
    num_labels = linear_classifier.module.num_labels if hasattr(linear_classifier, 'module') else linear_classifier.num_labels
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    for batch in metric_logger.log_every(val_loader, 20, header):
        inp, target = batch[:2]

        inp = inp.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # The whole function runs under @torch.no_grad(), so no inner guard is needed.
        intermediate_output, _ = model.get_intermediate_layers(inp, n)
        output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        if avgpool:
            output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
            output = output.reshape(output.shape[0], -1)
        output = linear_classifier(output)
        loss = nn.CrossEntropyLoss()(output, target)

        # Top-5 accuracy is only meaningful with at least five classes.
        if num_labels >= 5:
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        else:
            acc1, = utils.accuracy(output, target, topk=(1,))

        batch_size = inp.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        if num_labels >= 5:
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    if num_labels >= 5:
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
              .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    else:
        print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
              .format(top1=metric_logger.acc1, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features."""
    def __init__(self, dim, num_labels=1000):
        super(LinearClassifier, self).__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        # Flatten the concatenated features and apply the linear head.
        x = x.view(x.size(0), -1)
        return self.linear(x)
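

# Illustrative usage of the probe head (shapes are examples only): probing the
# concatenated [CLS] tokens of the last 4 blocks of ViT-Small (384-d each):
#
#     head = LinearClassifier(dim=384 * 4, num_labels=1000)
#     logits = head(torch.randn(8, 384 * 4))  # -> shape (8, 1000)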


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')
    parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
        for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
    parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
                        help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
        We typically set this to False for ViT-Small and to True with ViT-Base.""")
    parser.add_argument('--arch', default='vit_small', choices=['vit_small', 'vit_base'], type=str, help='Architecture')
    parser.add_argument('--dataset', default='cub', type=str, choices=['cub', 'imagenet'], help='Dataset')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--input_size', default=224, type=int, help='Input image size (default: 224).')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--checkpoint', default='', type=str, help='Path to a full training checkpoint, used with --evaluate.')
    parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
    parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch size')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
    parser.add_argument('--output_dir', default="./checkpoints", help='Path to save logs and checkpoints')
    parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
    parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
    parser.add_argument('--weight_decay', default=0.1, type=float, help="weight_decay, default 0.1")
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--distributed', default=False, action='store_true', help='enable distributed training')
    parser.add_argument('--adjacency_bp', default=False, action='store_true', help='whether to backprop from the adjacency matrix')
    parser.add_argument('--temperature', default=1, type=int, help='Temperature for mask')
    parser.add_argument('--seed', default=0, type=int)

    # Learning-rate schedule parameters (consumed by timm's create_scheduler).
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help="""Learning rate at the beginning of
        training (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.
        We recommend tweaking the LR depending on the checkpoint evaluated.""")
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
    parser.add_argument('--decay-epochs', type=float, default=5, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    parser.add_argument('--label-smooth-loss', default=False, action='store_true', help='use label smoothing on the loss')

    # Random erasing parameters.
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase the first (clean) augmentation split')

    # Mixup / CutMix parameters.
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Augmentation parameters.
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--no_center_crop', default=False, action='store_true', help='Disable center crop of the input image')
    parser.add_argument('--data-aug', action='store_true', default=False, help='Enable Mixup/CutMix data augmentation')
    parser.add_argument('--ori_size', default=False, action='store_true', help='Evaluate on the raw image size')
    args = parser.parse_args()
    main(args)