Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import numpy as np | |
| import torch | |
| from torch.nn import init | |
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Layers are matched by class name while walking ``net`` with ``net.apply``:
    Conv*/Linear* layers get their weight drawn from the scheme selected by
    ``init_type`` and their bias (when present) set to zero; BatchNorm2d
    layers get weight ~ N(1.0, init_gain) and bias 0.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Raises:
        NotImplementedError -- if ``init_type`` is not one of the names above
                               and a Conv/Linear layer is encountered.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and
    kaiming might work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # applied to every sub-module by net.apply()
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)

        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is not a matrix, so only a normal
            # distribution applies. With affine=False both weight and bias
            # are None and there is nothing to initialize.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
def create_logger(name, log_file, level=logging.INFO):
    """Return the logger called ``name``, writing to ``log_file`` and a stream.

    Records are formatted as ``[<asctime>] <message>`` and sent both to a
    ``logging.FileHandler`` on ``log_file`` and to a default
    ``logging.StreamHandler``.

    ``logging.getLogger`` hands back the same object for the same name, so
    calling this function repeatedly must not stack duplicate handlers
    (each record would then be emitted several times). A handler is only
    added when an equivalent one is not already attached.

    Parameters:
        name (str)     -- logger name passed to logging.getLogger
        log_file (str) -- path of the file to log to
        level (int)    -- logging level set on the logger

    Returns:
        the configured logging.Logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter('[%(asctime)s] %(message)s')

    # FileHandler stores the absolute path in baseFilename.
    target_file = os.path.abspath(log_file)
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target_file
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so exclude it explicitly.
    has_stream = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    if not has_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not has_stream:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
    return logger
class AverageMeter(object):
    """Computes and stores the average and current value.

    With ``length > 0`` the average is taken over the most recent ``length``
    values, which are kept in ``self.history``. Otherwise it is the running
    mean of every value passed to ``update`` since the last ``reset``
    (tracked through ``self.sum`` and ``self.count``).
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Clear all accumulated state."""
        # val/avg are defined in both modes so they can be read before the
        # first update() and do not keep stale values after a reset.
        self.val = 0.0
        self.avg = 0.0
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0

    def update(self, val):
        """Record ``val`` and refresh ``self.val`` / ``self.avg``."""
        if self.length > 0:
            self.history.append(val)
            # Keep only the newest `length` entries.
            if len(self.history) > self.length:
                del self.history[0]
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Parameters:
        output (tensor) -- 2-D scores, one row per sample; the top-k is taken
                           along dim 1
        target (tensor) -- ground-truth index for each sample; its first
                           dimension is used as the batch size
        topk (tuple)    -- the values of k to evaluate

    Returns:
        a list with one 1-element tensor per entry of ``topk``, holding the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample, sorted descending.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` keeps the transposed (non-contiguous) layout of `pred`,
        # so view(-1) can fail for k > 1; reshape copies when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def load_state(path, model, optimizer=None):
    """Load a checkpoint into ``model`` (and optionally ``optimizer``).

    The checkpoint is expected to be a dict with the keys ``'state_dict'``
    and ``'step'``, plus ``'optimizer'`` when an optimizer is passed. The
    model is loaded with ``strict=False``; every model key absent from the
    checkpoint is reported on stdout.

    Parameters:
        path (str)            -- checkpoint file
        model (network)       -- receives checkpoint['state_dict']
        optimizer (optimizer) -- if given, receives checkpoint['optimizer']

    Returns:
        checkpoint['step'] when the file exists, otherwise None (a message
        is printed and nothing is loaded).
    """
    # Fall back to CPU when CUDA is unavailable instead of crashing inside
    # storage.cuda().
    use_cuda = torch.cuda.is_available()

    def map_func(storage, location):
        return storage.cuda() if use_cuda else storage

    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return None

    print("=> loading checkpoint '{}'".format(path))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=map_func)
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    # strict=False hides absent parameters, so list them explicitly.
    ckpt_keys = set(checkpoint['state_dict'].keys())
    own_keys = set(model.state_dict().keys())
    missing_keys = own_keys - ckpt_keys
    for k in missing_keys:
        print('caution: missing keys from checkpoint {}: {}'.format(path, k))

    last_iter = checkpoint['step']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> also loaded optimizer from checkpoint '{}' (iter {})"
              .format(path, last_iter))
    return last_iter