""" This file provides utility functions for deep learning.
"""
import numpy as np
import torch

from torchtask.utils import logger

def sigmoid_rampup(current, rampup_length):
    """ Exponential rampup from https://arxiv.org/abs/1610.02242 .
    """
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))
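
# Example (illustrative, not part of the original file): over a 100-step
# rampup the weight grows smoothly from exp(-5) to 1, e.g.
#   sigmoid_rampup(0, 100)   -> ~0.0067
#   sigmoid_rampup(50, 100)  -> ~0.2865
#   sigmoid_rampup(100, 100) -> 1.0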

def split_tensor_tuple(ttuple, start, end, reduce_dim=False):
    """ Slice each tensor in the input tuple along the channel dim.

    Arguments:
        ttuple (tuple): tuple of torch.Tensor
        start (int): start index of the slice
        end (int): end index of the slice
        reduce_dim (bool): whether to squeeze the channel dim
            (requires end - start == 1)

    Returns:
        tuple: tuple of the sliced tensors
    """
    if reduce_dim:
        assert end - start == 1

    result = []
    for t in ttuple:
        if reduce_dim:
            result.append(t[start, ...])
        else:
            result.append(t[start:end, ...])
    return tuple(result)
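
# Example (illustrative): with tensors a and b both of shape (4, H, W),
#   split_tensor_tuple((a, b), 0, 2)                   -> (a[0:2, ...], b[0:2, ...])
#   split_tensor_tuple((a, b), 0, 1, reduce_dim=True)  -> (a[0, ...], b[0, ...])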

def combine_tensor_tuple(ttuple1, ttuple2, dim):
    """ Concatenate the tensors in two equal-length tuples pairwise along `dim`.
    """
    assert len(ttuple1) == len(ttuple2)

    result = []
    for t1, t2 in zip(ttuple1, ttuple2):
        result.append(torch.cat((t1, t2), dim=dim))
    return tuple(result)
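
# Example (illustrative): combine_tensor_tuple((a1, a2), (b1, b2), dim=0)
# returns (torch.cat((a1, b1), dim=0), torch.cat((a2, b2), dim=0)),
# roughly the inverse of split_tensor_tuple along the same dim.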

def create_model(mclass, mname, **kwargs):
    """ Create an nn.Module and set it up for multi-GPU training.
    """
    model = mclass(**kwargs)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    logger.log_info(' ' + '=' * 76 + '\n {0} parameters \n{1}'.format(mname, model_str(model)))
    return model
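
# Example (illustrative; MyNet is a hypothetical nn.Module subclass):
#   model = create_model(MyNet, 'MyNet', in_channels=3, num_classes=10)
# builds MyNet(in_channels=3, num_classes=10), wraps it in DataParallel,
# moves it onto the GPUs, and logs the parameter table from model_str().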

def model_str(module):
    """ Format the model structure and parameter counts as a string.
    """
    row_format = ' {name:<40} {shape:>20} = {total_size:>12,d}'
    lines = [' ' + '-' * 76]

    params = list(module.named_parameters())
    for name, param in params:
        lines.append(row_format.format(
            name=name,
            shape=' * '.join(str(p) for p in param.size()),
            total_size=param.numel()))

    lines.append(' ' + '-' * 76)
    lines.append(row_format.format(
        name='all parameters', shape='sum of above',
        total_size=sum(int(param.numel()) for name, param in params)))
    lines.append(' ' + '=' * 76)
    lines.append('')
    return '\n'.join(lines)
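
# Example output (illustrative):
#  ----------------------------------------------------------------------------
#  conv1.weight                                64 * 3 * 3 * 3 =        1,728
#  conv1.bias                                              64 =           64
#  ----------------------------------------------------------------------------
#  all parameters                                sum of above =        1,792
#  ============================================================================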

def pytorch_support(required_version='1.0.0', info_str=''):
    # NOTE: compare numeric version tuples rather than raw strings, since
    # string comparison mis-orders versions such as '1.10.0' < '1.9.0'.
    def _version_tuple(version):
        parts = version.split('+')[0].split('.')[:3]
        return tuple(int(''.join(c for c in p if c.isdigit()) or '0') for p in parts)

    if _version_tuple(torch.__version__) < _version_tuple(required_version):
        logger.log_err('{0} requires PyTorch >= {1}\n'
                       'However, current PyTorch == {2}\n'
                       .format(info_str, required_version, torch.__version__))
    else:
        return True
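
# Example (illustrative):
#   pytorch_support('1.1.0', 'Mixed precision training')
# returns True when the installed torch is new enough, and otherwise
# reports the version mismatch through logger.log_err.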