import numpy as np import torch from torchtask.utils import logger """ This file provides tool functions for deep learning. """ def sigmoid_rampup(current, rampup_length): """ Exponential rampup from https://arxiv.org/abs/1610.02242 . """ if rampup_length == 0: return 1.0 else: current = np.clip(current, 0.0, rampup_length) phase = 1.0 - current / rampup_length return float(np.exp(-5.0 * phase * phase)) def split_tensor_tuple(ttuple, start, end, reduce_dim=False): """ Slice each tensor in the input tuple by channel-dim. Arguments: ttuple (tuple): tuple of a torch.Tensor start (int): start index of slicing end (int): end index of slicing reduce_dim (bool): whether reduce the channel-dim when end - start == 1 Returns: tuple: a sliced tensor tuple """ result = [] if reduce_dim: assert end - start == 1 for t in ttuple: if end - start == 1 and reduce_dim: result.append(t[start, ...]) else: result.append(t[start:end, ...]) return tuple(result) def combine_tensor_tuple(ttuple1, ttuple2, dim): result = [] assert len(ttuple1) == len(ttuple2) for t1, t2 in zip(ttuple1, ttuple2): result.append(torch.cat((t1, t2), dim=dim)) return tuple(result) def create_model(mclass, mname, **kwargs): """ Create a nn.Module and setup it on multiple GPUs. """ model = mclass(**kwargs) model = torch.nn.DataParallel(model) model = model.cuda() logger.log_info(' ' + '=' * 76 + '\n {0} parameters \n{1}'.format(mname, model_str(model))) return model def model_str(module): """ Output model structure and parameters number as strings. """ row_format = ' {name:<40} {shape:>20} = {total_size:>12,d}' lines = [' ' + '-' * 76,] params = list(module.named_parameters()) for name, param in params: lines.append(row_format.format(name=name, shape=' * '.join(str(p) for p in param.size()), total_size=param.numel())) lines.append(' ' + '-' * 76) lines.append(row_format.format(name='all parameters', shape='sum of above', total_size=sum(int(param.numel()) for name, param in params))) lines.append(' ' + '=' * 76) lines.append('') return '\n'.join(lines) def pytorch_support(required_version='1.0.0', info_str=''): if torch.__version__ < required_version: logger.log_err('{0} required PyTorch >= {1}\n' 'However, current PyTorch == {2}\n' .format(info_str, required_version, torch.__version__)) else: return True