# Provenance: uploaded via huggingface_hub by user "Inmental", commit 4c62147 (verified).
import numpy as np
import torch
from torchtask.utils import logger
""" This file provides tool functions for deep learning.
"""
def sigmoid_rampup(current, rampup_length):
    """ Exponential rampup from https://arxiv.org/abs/1610.02242 .

    Returns 1.0 when rampup_length is 0; otherwise returns
    exp(-5 * (1 - clamp(current / rampup_length, 0, 1))^2),
    which ramps from exp(-5) at current=0 up to 1.0 at current>=rampup_length.
    """
    if rampup_length == 0:
        return 1.0
    progress = np.clip(current, 0.0, rampup_length) / rampup_length
    remaining = 1.0 - progress
    return float(np.exp(-5.0 * remaining * remaining))
def split_tensor_tuple(ttuple, start, end, reduce_dim=False):
    """ Slice each tensor in the input tuple along its first dimension.

    NOTE(review): the original docstring said "channel-dim", but the code
    indexes dim 0 (`t[start:end, ...]`) — presumably the batch dim for the
    callers of this helper; confirm against call sites.

    Arguments:
        ttuple (tuple): tuple of torch.Tensor
        start (int): start index of slicing
        end (int): end index of slicing
        reduce_dim (bool): if True, drop the sliced dimension
            (requires end - start == 1)

    Returns:
        tuple: a tuple of sliced tensors
    """
    if reduce_dim:
        # Dropping the dimension only makes sense for a single-element slice.
        assert end - start == 1
        # t[start, ...] indexes (rather than slices) dim 0, removing it.
        return tuple(t[start, ...] for t in ttuple)
    return tuple(t[start:end, ...] for t in ttuple)
def combine_tensor_tuple(ttuple1, ttuple2, dim):
    """ Concatenate corresponding tensors of two equal-length tuples.

    Each output element is torch.cat((t1, t2), dim=dim) for the pair of
    tensors at the same position in the two input tuples.
    """
    assert len(ttuple1) == len(ttuple2)
    return tuple(torch.cat(pair, dim=dim) for pair in zip(ttuple1, ttuple2))
def create_model(mclass, mname, **kwargs):
    """ Create a nn.Module and setup it on multiple GPUs.

    Arguments:
        mclass: the nn.Module subclass to instantiate
        mname (str): display name used in the logged parameter summary
        **kwargs: forwarded to the mclass constructor

    Returns:
        the constructed model, wrapped in torch.nn.DataParallel and moved
        to CUDA.
    """
    # Build, wrap for multi-GPU data parallelism, then move to GPU.
    model = torch.nn.DataParallel(mclass(**kwargs)).cuda()
    # Log the parameter table produced by model_str for this model.
    logger.log_info(' ' + '=' * 76 + '\n {0} parameters \n{1}'.format(mname, model_str(model)))
    return model
def model_str(module):
    """ Output model structure and parameters number as strings.

    Produces one row per named parameter (name, shape, element count),
    followed by a totals row summing all parameter counts.
    """
    row_format = ' {name:<40} {shape:>20} = {total_size:>12,d}'
    named = list(module.named_parameters())
    lines = [' ' + '-' * 76]
    lines.extend(
        row_format.format(
            name=pname,
            shape=' * '.join(str(dim) for dim in tensor.size()),
            total_size=tensor.numel(),
        )
        for pname, tensor in named
    )
    lines.append(' ' + '-' * 76)
    lines.append(row_format.format(
        name='all parameters',
        shape='sum of above',
        total_size=sum(int(tensor.numel()) for _, tensor in named),
    ))
    lines.append(' ' + '=' * 76)
    lines.append('')
    return '\n'.join(lines)
def pytorch_support(required_version='1.0.0', info_str=''):
    """ Check that the installed PyTorch meets a minimum version.

    The original compared version STRINGS, which is wrong lexicographically
    (e.g. '1.10.0' < '1.2.0' as strings). Versions are now compared as
    numeric tuples.

    Arguments:
        required_version (str): minimum PyTorch version, e.g. '1.0.0'
        info_str (str): caller description included in the error message

    Returns:
        True if the installed version is sufficient; otherwise logs an
        error via logger.log_err and returns None (falsy), matching the
        original behavior.
    """
    def _as_tuple(version):
        # Drop local build suffix (e.g. '2.1.0+cu118'), keep leading digits
        # of each dotted component (handles pre-release tags like '2.1.0a0').
        parts = []
        for piece in version.split('+')[0].split('.'):
            digits = ''
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits) if digits else 0)
        return tuple(parts)

    if _as_tuple(torch.__version__) < _as_tuple(required_version):
        logger.log_err('{0} required PyTorch >= {1}\n'
                       'However, current PyTorch == {2}\n'
                       .format(info_str, required_version, torch.__version__))
    else:
        return True