from __future__ import print_function import argparse import os import shutil import time import random import numpy as np import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim as optim import torch.utils.data as data import torchvision.transforms as transforms import torch.nn.functional as F import models.wideresnet as models import dataset.cifar10 as dataset from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig from tensorboardX import SummaryWriter parser = argparse.ArgumentParser(description='PyTorch MixMatch Training') # Optimization options parser.add_argument('--epochs', default=1024, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('--batch-size', default=64, type=int, metavar='N', help='train batchsize') parser.add_argument('--lr', '--learning-rate', default=0.002, type=float, metavar='LR', help='initial learning rate') # Checkpoints parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') # Miscs parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') #Device options parser.add_argument('--gpu', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES') #Method options parser.add_argument('--n-labeled', type=int, default=250, help='Number of labeled data') parser.add_argument('--train-iteration', type=int, default=1024, help='Number of iteration per epoch') parser.add_argument('--out', default='result', help='Directory to output the result') parser.add_argument('--alpha', default=0.75, type=float) parser.add_argument('--lambda-u', default=75, type=float) parser.add_argument('--T', default=0.5, type=float) parser.add_argument('--ema-decay', default=0.999, type=float) args = parser.parse_args() state = {k: v for k, v in args._get_kwargs()} # Use CUDA os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu use_cuda = torch.cuda.is_available() # Random seed if args.manualSeed is None: args.manualSeed = random.randint(1, 10000) np.random.seed(args.manualSeed) best_acc = 0 # best test accuracy def main(): global best_acc if not os.path.isdir(args.out): mkdir_p(args.out) # Data print(f'==> Preparing cifar10') transform_train = transforms.Compose([ dataset.RandomPadandCrop(32), dataset.RandomFlip(), dataset.ToTensor(), ]) transform_val = transforms.Compose([ dataset.ToTensor(), ]) train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val) labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0) test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0) print(1) # Model print("==> creating WRN-28-2") def create_model(ema=False): model = models.WideResNet(num_classes=10) model = model.cuda() if ema: for param in model.parameters(): param.detach_() return model model = create_model() ema_model = create_model(ema=True) cudnn.benchmark = True print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) train_criterion = SemiLoss() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=args.lr) ema_optimizer= WeightEMA(model, ema_model, alpha=args.ema_decay) start_epoch = 0 # Resume title = 'noisy-cifar-10' if args.resume: # Load checkpoint. print('==> Resuming from checkpoint..') assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' args.out = os.path.dirname(args.resume) checkpoint = torch.load(args.resume) best_acc = checkpoint['best_acc'] start_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) ema_model.load_state_dict(checkpoint['ema_state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) logger = Logger(os.path.join(args.out, 'log.txt'), title=title, resume=True) else: logger = Logger(os.path.join(args.out, 'log.txt'), title=title) logger.set_names(['Train Loss', 'Train Loss X', 'Train Loss U', 'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.']) writer = SummaryWriter(args.out) step = 0 test_accs = [] # Train and val for epoch in range(start_epoch, args.epochs): print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda) _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats') val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats') test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ') step = args.train_iteration * (epoch + 1) writer.add_scalar('losses/train_loss', train_loss, step) writer.add_scalar('losses/valid_loss', val_loss, step) writer.add_scalar('losses/test_loss', test_loss, step) writer.add_scalar('accuracy/train_acc', train_acc, step) writer.add_scalar('accuracy/val_acc', val_acc, step) writer.add_scalar('accuracy/test_acc', test_acc, step) # append logger file logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc]) # save model is_best = val_acc > best_acc best_acc = max(val_acc, best_acc) save_checkpoint({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'ema_state_dict': ema_model.state_dict(), 'acc': val_acc, 'best_acc': best_acc, 'optimizer' : optimizer.state_dict(), }, is_best) test_accs.append(test_acc) logger.close() writer.close() print('Best acc:') print(best_acc) print('Mean acc:') print(np.mean(test_accs[-20:])) def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() losses_x = AverageMeter() losses_u = AverageMeter() ws = AverageMeter() end = time.time() bar = Bar('Training', max=args.train_iteration) labeled_train_iter = iter(labeled_trainloader) unlabeled_train_iter = iter(unlabeled_trainloader) model.train() for batch_idx in range(args.train_iteration): try: # inputs_x, targets_x = labeled_train_iter.next() inputs_x, targets_x = next(labeled_train_iter) except: labeled_train_iter = iter(labeled_trainloader) # inputs_x, targets_x = labeled_train_iter.next() inputs_x, targets_x = next(labeled_train_iter) try: # (inputs_u, inputs_u2), _ = unlabeled_train_iter.next() (inputs_u, inputs_u2), _ = next(unlabeled_train_iter) except: unlabeled_train_iter = iter(unlabeled_trainloader) # (inputs_u, inputs_u2), _ = unlabeled_train_iter.next() (inputs_u, inputs_u2), _ = next(unlabeled_train_iter) # measure data loading time data_time.update(time.time() - end) batch_size = inputs_x.size(0) # Transform label to one-hot targets_x = torch.zeros(batch_size, 10).scatter_(1, targets_x.view(-1,1).long(), 1) if use_cuda: inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True) # inputs = inputs_x.cuda() # targets = targets.type(torch.LongTensor).cuda(non_blocking=True) inputs_u = inputs_u.cuda() inputs_u2 = inputs_u2.cuda() with torch.no_grad(): # compute guessed labels of unlabel samples outputs_u = model(inputs_u) outputs_u2 = model(inputs_u2) p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2 pt = p**(1/args.T) targets_u = pt / pt.sum(dim=1, keepdim=True) targets_u = targets_u.detach() # mixup all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0) all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0) l = np.random.beta(args.alpha, args.alpha) l = max(l, 1-l) idx = torch.randperm(all_inputs.size(0)) input_a, input_b = all_inputs, all_inputs[idx] target_a, target_b = all_targets, all_targets[idx] mixed_input = l * input_a + (1 - l) * input_b mixed_target = l * target_a + (1 - l) * target_b # interleave labeled and unlabed samples between batches to get correct batchnorm calculation mixed_input = list(torch.split(mixed_input, batch_size)) mixed_input = interleave(mixed_input, batch_size) logits = [model(mixed_input[0])] for input in mixed_input[1:]: logits.append(model(input)) # put interleaved samples back logits = interleave(logits, batch_size) logits_x = logits[0] logits_u = torch.cat(logits[1:], dim=0) Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/args.train_iteration) loss = Lx + w * Lu # record loss losses.update(loss.item(), inputs_x.size(0)) losses_x.update(Lx.item(), inputs_x.size(0)) losses_u.update(Lu.item(), inputs_x.size(0)) ws.update(w, inputs_x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() ema_optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() # plot progress bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format( batch=batch_idx + 1, size=args.train_iteration, data=data_time.avg, bt=batch_time.avg, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, loss_x=losses_x.avg, loss_u=losses_u.avg, w=ws.avg, ) bar.next() bar.finish() return (losses.avg, losses_x.avg, losses_u.avg,) def validate(valloader, model, criterion, epoch, use_cuda, mode): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to evaluate mode model.eval() end = time.time() bar = Bar(f'{mode}', max=len(valloader)) with torch.no_grad(): for batch_idx, (inputs, targets) in enumerate(valloader): # measure data loading time data_time.update(time.time() - end) if use_cuda: inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) # compute output outputs = model(inputs) loss = criterion(outputs, targets) # measure accuracy and record loss prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) losses.update(loss.item(), inputs.size(0)) top1.update(prec1.item(), inputs.size(0)) top5.update(prec5.item(), inputs.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() # plot progress bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( batch=batch_idx + 1, size=len(valloader), data=data_time.avg, bt=batch_time.avg, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, top1=top1.avg, top5=top5.avg, ) bar.next() bar.finish() return (losses.avg, top1.avg) def save_checkpoint(state, is_best, checkpoint=args.out, filename='checkpoint.pth.tar'): filepath = os.path.join(checkpoint, filename) torch.save(state, filepath) if is_best: shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) def linear_rampup(current, rampup_length=args.epochs): if rampup_length == 0: return 1.0 else: current = np.clip(current / rampup_length, 0.0, 1.0) return float(current) class SemiLoss(object): def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch): probs_u = torch.softmax(outputs_u, dim=1) Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1)) Lu = torch.mean((probs_u - targets_u)**2) return Lx, Lu, args.lambda_u * linear_rampup(epoch) class WeightEMA(object): def __init__(self, model, ema_model, alpha=0.999): self.model = model self.ema_model = ema_model self.alpha = alpha self.params = list(model.state_dict().values()) self.ema_params = list(ema_model.state_dict().values()) self.wd = 0.02 * args.lr for param, ema_param in zip(self.params, self.ema_params): param.data.copy_(ema_param.data) def step(self): one_minus_alpha = 1.0 - self.alpha for param, ema_param in zip(self.params, self.ema_params): if ema_param.dtype==torch.float32: ema_param.mul_(self.alpha) ema_param.add_(param * one_minus_alpha) # customized weight decay param.mul_(1 - self.wd) def interleave_offsets(batch, nu): groups = [batch // (nu + 1)] * (nu + 1) for x in range(batch - sum(groups)): groups[-x - 1] += 1 offsets = [0] for g in groups: offsets.append(offsets[-1] + g) assert offsets[-1] == batch return offsets def interleave(xy, batch): nu = len(xy) - 1 offsets = interleave_offsets(batch, nu) xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy] for i in range(1, nu + 1): xy[0][i], xy[i][i] = xy[i][i], xy[0][i] return [torch.cat(v, dim=0) for v in xy] if __name__ == '__main__': main()