| |
| |
| |
| |
| |
|
|
| from pathlib import Path |
| import argparse |
| import json |
| import os |
| import random |
| import signal |
| import sys |
| import time |
| import urllib |
|
|
| from torch import nn, optim |
| from torchvision import models, datasets, transforms |
| import torch |
| import torchvision |
| import wandb |
|
|
# Command-line interface for the evaluation script. Two positional arguments
# (dataset root and pretrained checkpoint) are required; everything else has a
# default. `main()` parses this at start-up and passes the resulting namespace
# to every worker process.
parser = argparse.ArgumentParser(description='Evaluate resnet50 features on ImageNet')
parser.add_argument('data', type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('pretrained', type=Path, metavar='FILE',
                    help='path to pretrained model')
parser.add_argument('--weights', default='freeze', type=str,
                    choices=('finetune', 'freeze'),
                    help='finetune or freeze resnet weights')
parser.add_argument('--train-percent', default=100, type=int,
                    choices=(100, 10, 1),
                    help='size of training set in percent')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
# The batch size is divided by the world size when the loaders are built.
parser.add_argument('--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--lr-backbone', default=0.0, type=float, metavar='LR',
                    help='backbone base learning rate')
parser.add_argument('--lr-classifier', default=0.3, type=float, metavar='LR',
                    help='classifier base learning rate')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--print-freq', default=100, type=int, metavar='N',
                    help='print frequency')
# A per-run sub-directory (named after the wandb run id) is created below this.
parser.add_argument('--checkpoint-dir', default='/mnt/store/wbandar1/projects/ssl-aug-artifacts/', type=Path,
                    metavar='DIR', help='path to checkpoint directory')
|
|
|
|
def main():
    """Parse the command line and spawn one worker process per visible GPU.

    When ``--train-percent`` is 1 or 10, the list of training file names for
    that ImageNet subset is downloaded from the SimCLR repository and stored
    (as raw byte lines) on ``args.train_files``. Under SLURM, handlers for
    SIGUSR1 and SIGTERM are installed before the workers are started.
    """
    # `import urllib` alone does not guarantee that the `request` submodule is
    # loaded; import it explicitly so `urllib.request.urlopen` always resolves.
    import urllib.request

    args = parser.parse_args()
    if args.train_percent in {1, 10}:
        subset_url = f'https://raw.githubusercontent.com/google-research/simclr/master/imagenet_subsets/{args.train_percent}percent.txt'
        # Close the HTTP response as soon as the file list has been read.
        with urllib.request.urlopen(subset_url) as response:
            args.train_files = response.readlines()
    args.ngpus_per_node = torch.cuda.device_count()
    if 'SLURM_JOB_ID' in os.environ:
        signal.signal(signal.SIGUSR1, handle_sigusr1)
        signal.signal(signal.SIGTERM, handle_sigterm)

    # Single-node setup: ranks are 0..ngpus-1 and rendezvous on a random
    # localhost port taken from the dynamic/private port range.
    args.rank = 0
    args.dist_url = f'tcp://localhost:{random.randrange(49152, 65535)}'
    args.world_size = args.ngpus_per_node
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)
|
|
|
|
def main_worker(gpu, args):
    """Train and evaluate the classifier on one GPU as one distributed rank.

    Loads the pretrained ResNet-50 backbone, attaches a freshly initialised
    ``fc`` layer, trains it (optionally fine-tuning the backbone) with SGD and
    a cosine schedule, and evaluates on the validation set after every epoch.
    Rank 0 additionally prints/logs statistics, writes ``stats.txt`` and saves
    ``checkpoint.pth`` in a run-specific checkpoint directory.

    Args:
        gpu: Process index passed by ``torch.multiprocessing.spawn``; used as
            the CUDA device index and added to ``args.rank``.
        args: Namespace produced by ``main``. It is mutated in place:
            ``rank`` is offset by ``gpu`` and ``checkpoint_dir`` gets the
            wandb run id appended.
    """
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    # Select the device before the first collective call below.
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    run = None
    stats_file = None
    shared_run_id = [None]
    if args.rank == 0:
        run = wandb.init(project="bt-in1k-eval", config=args, dir='/mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs/')
        shared_run_id[0] = wandb.run.id
    # Every rank must use the same checkpoint directory; otherwise the resume
    # check further down can give different answers on different ranks and the
    # ranks would disagree about `start_epoch`.
    torch.distributed.broadcast_object_list(shared_run_id, src=0)
    args.checkpoint_dir = Path(os.path.join(args.checkpoint_dir, shared_run_id[0]))

    try:
        if args.rank == 0:
            args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
            stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
            print(' '.join(sys.argv))
            print(' '.join(sys.argv), file=stats_file)

        # Backbone: the pretrained weights must cover everything except `fc`.
        model = models.resnet50().cuda(gpu)
        state_dict = torch.load(args.pretrained, map_location='cpu')
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        assert missing_keys == ['fc.weight', 'fc.bias'] and unexpected_keys == []
        model.fc.weight.data.normal_(mean=0.0, std=0.01)
        model.fc.bias.data.zero_()
        if args.weights == 'freeze':
            model.requires_grad_(False)
            model.fc.requires_grad_(True)
        # Split parameters so classifier and backbone can use different rates.
        classifier_parameters, model_parameters = [], []
        for name, param in model.named_parameters():
            if name in {'fc.weight', 'fc.bias'}:
                classifier_parameters.append(param)
            else:
                model_parameters.append(param)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

        criterion = nn.CrossEntropyLoss().cuda(gpu)

        # The backbone group is only optimised when fine-tuning.
        param_groups = [dict(params=classifier_parameters, lr=args.lr_classifier)]
        if args.weights == 'finetune':
            param_groups.append(dict(params=model_parameters, lr=args.lr_backbone))
        optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

        # Resume automatically if this run's directory already has a checkpoint.
        if (args.checkpoint_dir / 'checkpoint.pth').is_file():
            ckpt = torch.load(args.checkpoint_dir / 'checkpoint.pth',
                              map_location='cpu')
            start_epoch = ckpt['epoch']
            best_acc = ckpt['best_acc']
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
        else:
            start_epoch = 0
            best_acc = argparse.Namespace(top1=0, top5=0)

        # Data loading.
        traindir = args.data / 'train'
        valdir = args.data / 'val'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

        # For the 1% / 10% settings, replace the sample list with the files
        # downloaded by `main`; the class id is the file-name prefix before '_'.
        if args.train_percent in {1, 10}:
            train_dataset.samples = []
            for fname in args.train_files:
                fname = fname.decode().strip()
                cls = fname.split('_')[0]
                train_dataset.samples.append(
                    (traindir / cls / fname, train_dataset.class_to_idx[cls]))

        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        kwargs = dict(batch_size=args.batch_size // args.world_size, num_workers=args.workers, pin_memory=True)
        train_loader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **kwargs)
        val_loader = torch.utils.data.DataLoader(val_dataset, **kwargs)

        start_time = time.time()
        for epoch in range(start_epoch, args.epochs):
            # Train. With frozen weights the network stays in eval mode.
            if args.weights == 'finetune':
                model.train()
            elif args.weights == 'freeze':
                model.eval()
            else:
                assert False
            train_sampler.set_epoch(epoch)
            # `step` is a global step counter that continues across epochs.
            for step, (images, target) in enumerate(train_loader, start=epoch * len(train_loader)):
                output = model(images.cuda(gpu, non_blocking=True))
                loss = criterion(output, target.cuda(gpu, non_blocking=True))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if step % args.print_freq == 0:
                    # Sum the pre-divided losses onto rank 0 (i.e. the mean).
                    torch.distributed.reduce(loss.div_(args.world_size), 0)
                    if args.rank == 0:
                        pg = optimizer.param_groups
                        lr_classifier = pg[0]['lr']
                        lr_backbone = pg[1]['lr'] if len(pg) == 2 else 0
                        stats = dict(epoch=epoch, step=step, lr_backbone=lr_backbone,
                                     lr_classifier=lr_classifier, loss=loss.item(),
                                     time=int(time.time() - start_time))
                        print(json.dumps(stats))
                        print(json.dumps(stats), file=stats_file)
                        # Same keys and values as the printed record.
                        run.log(stats)

            # Evaluate on the full validation set (rank 0 only).
            model.eval()
            if args.rank == 0:
                top1 = AverageMeter('Acc@1')
                top5 = AverageMeter('Acc@5')
                with torch.no_grad():
                    for images, target in val_loader:
                        output = model(images.cuda(gpu, non_blocking=True))
                        acc1, acc5 = accuracy(output, target.cuda(gpu, non_blocking=True), topk=(1, 5))
                        top1.update(acc1[0].item(), images.size(0))
                        top5.update(acc5[0].item(), images.size(0))
                best_acc.top1 = max(best_acc.top1, top1.avg)
                best_acc.top5 = max(best_acc.top5, top5.avg)
                stats = dict(epoch=epoch, acc1=top1.avg, acc5=top5.avg, best_acc1=best_acc.top1, best_acc5=best_acc.top5)
                print(json.dumps(stats))
                print(json.dumps(stats), file=stats_file)
                run.log(
                    {
                        "epoch": epoch,
                        "eval_acc1": top1.avg,
                        "eval_acc5": top5.avg,
                        "eval_best_acc1": best_acc.top1,
                        "eval_best_acc5": best_acc.top5,
                    }
                )

            # Sanity check: frozen backbone tensors must equal the pretrained
            # file exactly after every epoch.
            if args.weights == 'freeze':
                reference_state_dict = torch.load(args.pretrained, map_location='cpu')
                model_state_dict = model.module.state_dict()
                for k in reference_state_dict:
                    assert torch.equal(model_state_dict[k].cpu(), reference_state_dict[k]), k

            scheduler.step()
            if args.rank == 0:
                # `epoch + 1` is stored so a resumed run starts at the next epoch.
                state = dict(
                    epoch=epoch + 1, best_acc=best_acc, model=model.state_dict(),
                    optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict())
                torch.save(state, args.checkpoint_dir / 'checkpoint.pth')
        wandb.finish()
    finally:
        # Only rank 0 opens the stats file; release it even if training fails.
        if stats_file is not None:
            stats_file.close()
|
|
|
|
def handle_sigusr1(signum, frame):
    """Signal handler: ask SLURM to requeue the current job, then exit.

    Args:
        signum: Signal number delivered to the handler (unused).
        frame: Current stack frame at delivery time (unused).

    Raises:
        SystemExit: Always, after the requeue command has been issued.
    """
    os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
    # `exit()` is a helper injected by the `site` module for interactive use
    # and may be missing (e.g. `python -S`); `sys.exit()` is always available.
    sys.exit()
|
|
|
|
def handle_sigterm(signum, frame):
    """Signal handler that deliberately does nothing.

    Installing it replaces the default action for the signal it is
    registered for, so delivery of that signal is ignored.
    """
    return None
|
|
|
|
class AverageMeter:
    """Track the latest value and the weighted running mean of a metric."""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format-spec suffix (including the leading ':') that is
        # spliced into the template used by `__str__`.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, weighted by `n` observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. '{name} {val:f} ({avg:f})' and fill it from the
        # instance attributes.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) for every k listed in `topk`.

    `output` is indexed as (sample, class) scores and `target` holds one
    class index per sample. The result is a list with one single-element
    tensor per requested k, in the same order as `topk`.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the `largest_k` highest scores per sample, transposed so
        # that row r holds every sample's (r+1)-th ranked prediction.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()