File size: 2,708 Bytes
4c62147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np

import torch

from torchtask.utils import logger


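""" Utility functions for deep learning: ramp-up weight schedules, slicing and
    concatenation of tensor tuples, model construction/logging, and PyTorch
    version checks.
"""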


def sigmoid_rampup(current, rampup_length):
    """ Sigmoid-shaped ramp-up weight from https://arxiv.org/abs/1610.02242 .

    Computes exp(-5 * (1 - t) ** 2), where t = clip(current, 0, rampup_length) / rampup_length.
    The weight therefore rises from exp(-5) (for current <= 0) to 1.0
    (for current >= rampup_length).

    Arguments:
        current (int or float): current step/epoch
        rampup_length (int or float): length of the ramp-up; 0 disables it

    Returns:
        float: the ramp-up weight (always 1.0 when rampup_length == 0)
    """
    if rampup_length == 0:
        return 1.0

    # fraction of the ramp-up that has been completed, kept within [0, 1]
    progress = np.clip(current, 0.0, rampup_length) / rampup_length
    remaining = 1.0 - progress
    return float(np.exp(-5.0 * remaining * remaining))



def split_tensor_tuple(ttuple, start, end, reduce_dim=False):
    """ Slice each tensor in the input tuple along its first dimension (dim 0).

    Arguments:
        ttuple (tuple): tuple of torch.Tensor
        start (int): start index of slicing
        end (int): end index of slicing (exclusive)
        reduce_dim (bool): if True, index with ``t[start]`` instead of slicing,
            which drops dim 0; only allowed when end - start == 1

    Returns:
        tuple: the sliced tensors, in the same order as ``ttuple``

    Raises:
        ValueError: if reduce_dim is True but end - start != 1
    """
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which would silently return non-reduced slices for a bad range.
    if reduce_dim and end - start != 1:
        raise ValueError('reduce_dim=True requires end - start == 1, '
                         'got start={0}, end={1}'.format(start, end))

    if reduce_dim:
        return tuple(t[start, ...] for t in ttuple)
    return tuple(t[start:end, ...] for t in ttuple)


def combine_tensor_tuple(ttuple1, ttuple2, dim):
    """ Concatenate two tensor tuples element-wise along dimension ``dim``.

    Arguments:
        ttuple1 (tuple): tuple of torch.Tensor
        ttuple2 (tuple): tuple of torch.Tensor, same length as ttuple1
        dim (int): dimension along which each pair of tensors is concatenated

    Returns:
        tuple: ``torch.cat((ttuple1[i], ttuple2[i]), dim=dim)`` for each index i

    Raises:
        ValueError: if the two tuples have different lengths
    """
    # Explicit check rather than `assert`: under `python -O` the assert vanishes
    # and `zip` would silently drop the unmatched trailing tensors.
    if len(ttuple1) != len(ttuple2):
        raise ValueError('tensor tuples must have the same length, '
                         'got {0} and {1}'.format(len(ttuple1), len(ttuple2)))

    return tuple(torch.cat((t1, t2), dim=dim) for t1, t2 in zip(ttuple1, ttuple2))


def create_model(mclass, mname, **kwargs):
    """ Instantiate ``mclass(**kwargs)``, wrap it in torch.nn.DataParallel and
    move it to CUDA, then log its parameter table under the title ``mname``.

    Arguments:
        mclass (type): nn.Module subclass to instantiate
        mname (str): display name used in the logged parameter table
        **kwargs: forwarded unchanged to ``mclass``

    Returns:
        torch.nn.DataParallel: the wrapped model after ``.cuda()``
    """
    wrapped = torch.nn.DataParallel(mclass(**kwargs)).cuda()

    banner = '  ' + '=' * 76 + '\n  {0} parameters \n{1}'.format(mname, model_str(wrapped))
    logger.log_info(banner)
    return wrapped


def model_str(module):
    """ Render the parameters of ``module`` as a fixed-width text table.

    Each named parameter gets one row with its name, its shape written as
    ``d0 * d1 * ...``, and its element count. A final row sums the element
    counts of all listed parameters.

    Returns:
        str: the table, one row per line, ending with a trailing newline
    """
    row = '  {name:<40} {shape:>20} = {total_size:>12,d}'.format
    thin_rule = '  ' + '-' * 76
    thick_rule = '  ' + '=' * 76

    named = list(module.named_parameters())
    rows = [
        row(name=pname, shape=' * '.join(map(str, tensor.size())), total_size=tensor.numel())
        for pname, tensor in named
    ]
    grand_total = sum(int(tensor.numel()) for _, tensor in named)
    summary = row(name='all parameters', shape='sum of above', total_size=grand_total)

    # the trailing '' yields a final newline after the closing rule
    return '\n'.join([thin_rule, *rows, thin_rule, summary, thick_rule, ''])


def pytorch_support(required_version='1.0.0', info_str=''):
    if torch.__version__ < required_version:
        logger.log_err('{0} required PyTorch >= {1}\n'
                       'However, current PyTorch == {2}\n'
                       .format(info_str, required_version, torch.__version__))
    else:
        return True