Spaces:
Build error
Build error
| # Modified from: | |
| # https://github.com/anibali/pytorch-stacked-hourglass | |
| # https://github.com/bearpaw/pytorch-pose | |
| import os | |
| import shutil | |
| import scipy.io | |
| import torch | |
def to_numpy(tensor):
    """Return *tensor* as a NumPy object.

    A torch tensor is detached from the autograd graph, moved to the CPU and
    converted with ``.numpy()``. Any object whose type is defined in the
    top-level ``numpy`` module (e.g. ``np.ndarray``) is returned unchanged.

    Raises:
        ValueError: if *tensor* is neither a torch tensor nor a numpy object.
    """
    if not torch.is_tensor(tensor):
        # Already a numpy object: nothing to convert.
        if type(tensor).__module__ == 'numpy':
            return tensor
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor.detach().cpu().numpy()
def to_torch(ndarray):
    """Return *ndarray* as a torch tensor.

    NumPy arrays are wrapped with ``torch.from_numpy``, so the returned tensor
    shares memory with the array. NumPy scalars (e.g. ``np.float32(1.0)``, as
    returned by ``np.mean``) are first turned into 0-d arrays, because
    ``torch.from_numpy`` only accepts ``np.ndarray`` and would otherwise raise
    ``TypeError``. Torch tensors are returned unchanged.

    Raises:
        ValueError: if *ndarray* is neither a numpy object nor a torch tensor.
    """
    if type(ndarray).__module__ == 'numpy':
        # Local import: numpy is only needed on this branch and is not
        # imported at module level.
        import numpy as np
        if isinstance(ndarray, np.generic):
            # Numpy scalar types report __module__ == 'numpy' too, but
            # torch.from_numpy rejects them; lift to a 0-d array first.
            ndarray = np.asarray(ndarray)
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray
def save_checkpoint(state, preds, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None):
    """Persist a training checkpoint and its predictions under *checkpoint*.

    Writes *state* with ``torch.save`` to ``checkpoint/filename`` and the
    predictions (converted via ``to_numpy``) to ``checkpoint/preds.mat``.
    When *snapshot* is truthy and ``state['epoch']`` is a multiple of it, the
    checkpoint file is also copied to ``checkpoint_<epoch>.pth.tar``. When
    *is_best* is true, the checkpoint is copied to ``model_best.pth.tar`` and
    the predictions are also written to ``preds_best.mat``.
    """
    preds = to_numpy(preds)

    def _dump_preds(mat_name):
        # Store the predictions as a MATLAB file under the variable 'preds'.
        scipy.io.savemat(os.path.join(checkpoint, mat_name), mdict={'preds' : preds})

    ckpt_path = os.path.join(checkpoint, filename)
    torch.save(state, ckpt_path)
    _dump_preds('preds.mat')

    # state['epoch'] is only read when snapshotting is enabled.
    if snapshot and state['epoch'] % snapshot == 0:
        snapshot_name = 'checkpoint_{}.pth.tar'.format(state['epoch'])
        shutil.copyfile(ckpt_path, os.path.join(checkpoint, snapshot_name))

    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(checkpoint, 'model_best.pth.tar'))
        _dump_preds('preds_best.mat')
def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
    """Write *preds* to ``checkpoint/filename`` as a MATLAB ``.mat`` file.

    The predictions are converted with ``to_numpy`` and stored under the
    variable name ``preds``.
    """
    as_array = to_numpy(preds)
    target = os.path.join(checkpoint, filename)
    scipy.io.savemat(target, mdict={'preds' : as_array})
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Step-decay the learning rate at scheduled milestone epochs.

    If *epoch* is in *schedule*, the given *lr* (the current rate, not the
    initial one) is multiplied by *gamma* and written to every entry of
    ``optimizer.param_groups``. Otherwise the optimizer is left untouched.

    Returns:
        The learning rate in effect after this call.
    """
    if epoch not in schedule:
        return lr
    lr *= gamma
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr