# Symbol-construction helpers for the nowcasting models. All learnable
# parameters are created through a global ParamsReg so that layers built in
# different symbols can share weights by name.
import mxnet as mx
import numpy as np

from nowcasting.config import cfg


class ParamsReg(object):
    """Registry mapping parameter names to mx.sym.Variable objects.

    Requesting the same name twice returns the same Variable, which is how
    weight sharing across separately constructed symbols is implemented.
    """

    def __init__(self):
        self._params = {}
        self._old_params = []

    def get(self, name, **kwargs):
        if name not in self._params:
            self._params[name] = mx.sym.Variable(name, dtype=np.float32,
                                                 **kwargs)
        return self._params[name]

    def get_inner(self):
        return self._params

    def reset(self):
        # Archive the current parameters and start with an empty registry.
        self._old_params.append(self._params)
        self._params = {}


_params = ParamsReg()


def reset_regs():
    global _params
    _params.reset()


def activation(data, act_type, name=None):
    if act_type == "leaky":
        if name is None:
            act = mx.sym.LeakyReLU(data=data, slope=0.2)
        else:
            act = mx.sym.LeakyReLU(data=data, slope=0.2,
                                   name='%s_%s' % (name, act_type))
    elif act_type == "identity":
        act = data
    else:
        if name is None:
            act = mx.sym.Activation(data=data, act_type=act_type)
        else:
            act = mx.sym.Activation(data=data, act_type=act_type,
                                    name='%s_%s' % (name, act_type))
    return act


def conv2d(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
           dilate=(1, 1), no_bias=False, name=None, **kwargs):
    assert name is not None
    global _params
    weight = _params.get('%s_weight' % name, **kwargs)
    if no_bias:
        conv = mx.sym.Convolution(data=data, num_filter=num_filter,
                                  kernel=kernel, stride=stride, weight=weight,
                                  dilate=dilate, no_bias=True, pad=pad,
                                  name=name, workspace=256)
    else:
        # Biases are excluded from weight decay via wd_mult=0.0.
        bias = _params.get('%s_bias' % name, wd_mult=0.0, **kwargs)
        conv = mx.sym.Convolution(data=data, num_filter=num_filter,
                                  kernel=kernel, stride=stride, weight=weight,
                                  bias=bias, dilate=dilate, no_bias=no_bias,
                                  pad=pad, name=name, workspace=256)
    return conv


def conv2d_bn_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                  dilate=(1, 1), no_bias=False, act_type="relu", momentum=0.9,
                  eps=1e-5 + 1e-12, fix_gamma=True, name=None,
                  use_global_stats=False, **kwargs):
    # eps is nudged just above 1e-5, the smallest value CuDNN's batch norm
    # accepts.
    conv = conv2d(data=data, num_filter=num_filter, kernel=kernel,
                  stride=stride, pad=pad, dilate=dilate, no_bias=no_bias,
                  name=name, **kwargs)
    assert name is not None
    global _params
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    bn = mx.sym.BatchNorm(data=conv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          name='%s_bn' % name,
                          use_global_stats=use_global_stats)
    act = activation(bn, act_type=act_type, name=name)
    return act


def conv2d_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
               dilate=(1, 1), no_bias=False, act_type="relu", name=None,
               **kwargs):
    conv = conv2d(data=data, num_filter=num_filter, kernel=kernel,
                  stride=stride, pad=pad, dilate=dilate, no_bias=no_bias,
                  name=name, **kwargs)
    act = activation(conv, act_type=act_type, name=name)
    return act
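
# Illustrative only (not part of the original module): two separately built
# symbols that pass the same `name` to conv2d resolve to the identical
# weight/bias Variables through ParamsReg, so binding them with a shared
# parameter dict shares storage.
def _weight_sharing_sketch():
    net_a = conv2d(data=mx.sym.Variable('data_a'), num_filter=8,
                   kernel=(3, 3), pad=(1, 1), name='shared')
    net_b = conv2d(data=mx.sym.Variable('data_b'), num_filter=8,
                   kernel=(3, 3), pad=(1, 1), name='shared')
    # Both graphs expose the same parameter names.
    assert 'shared_weight' in net_a.list_arguments()
    assert 'shared_weight' in net_b.list_arguments()
    return net_a, net_b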

def deconv2d(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
             adj=(0, 0), no_bias=True, target_shape=None, name="deconv2d",
             **kwargs):
    global _params
    assert name is not None
    weight = _params.get('%s_weight' % name, **kwargs)
    # When target_shape is given, MXNet derives the output size from it
    # instead of relying solely on adj.
    if no_bias:
        if target_shape is None:
            deconv = mx.sym.Deconvolution(data=data, num_filter=num_filter,
                                          kernel=kernel, adj=adj,
                                          stride=stride, no_bias=True,
                                          weight=weight, pad=pad, name=name)
        else:
            deconv = mx.sym.Deconvolution(data=data, num_filter=num_filter,
                                          kernel=kernel, adj=adj,
                                          stride=stride,
                                          target_shape=target_shape,
                                          no_bias=True, weight=weight,
                                          pad=pad, name=name)
    else:
        bias = _params.get('%s_bias' % name, wd_mult=0.0, **kwargs)
        if target_shape is None:
            deconv = mx.sym.Deconvolution(data=data, num_filter=num_filter,
                                          kernel=kernel, adj=adj,
                                          stride=stride, no_bias=no_bias,
                                          weight=weight, bias=bias, pad=pad,
                                          name=name)
        else:
            deconv = mx.sym.Deconvolution(data=data, num_filter=num_filter,
                                          kernel=kernel, adj=adj,
                                          stride=stride,
                                          target_shape=target_shape,
                                          no_bias=no_bias, weight=weight,
                                          bias=bias, pad=pad, name=name)
    return deconv


def deconv2d_bn_act(data, num_filter, kernel=(1, 1), stride=(1, 1),
                    pad=(0, 0), adj=(0, 0), no_bias=True, target_shape=None,
                    act_type="relu", momentum=0.9, eps=1e-5 + 1e-12,
                    fix_gamma=True, name="deconv2d", use_global_stats=False,
                    **kwargs):
    global _params
    deconv = deconv2d(data=data, num_filter=num_filter, kernel=kernel,
                      stride=stride, pad=pad, adj=adj,
                      target_shape=target_shape, no_bias=no_bias, name=name,
                      **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    bn = mx.sym.BatchNorm(data=deconv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats,
                          name='%s_bn' % name)
    act = activation(bn, act_type=act_type, name=name)
    return act


def deconv2d_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                 adj=(0, 0), no_bias=True, target_shape=None, act_type="relu",
                 name="deconv2d", **kwargs):
    deconv = deconv2d(data=data, num_filter=num_filter, kernel=kernel,
                      stride=stride, pad=pad, adj=adj,
                      target_shape=target_shape, no_bias=no_bias, name=name,
                      **kwargs)
    act = activation(deconv, act_type=act_type, name=name)
    return act
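
# Illustrative only (not part of the original module): a kernel=(4, 4),
# stride=(2, 2), pad=(1, 1) deconvolution doubles the spatial size, since
# out = (in - 1) * stride - 2 * pad + kernel.
def _deconv_upsample_sketch():
    data = mx.sym.Variable('up_data')
    out = deconv2d(data=data, num_filter=4, kernel=(4, 4), stride=(2, 2),
                   pad=(1, 1), name='up_deconv')
    _, out_shapes, _ = out.infer_shape(up_data=(1, 8, 16, 16))
    return out_shapes[0]  # (1, 4, 32, 32)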

def conv3d(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1),
           pad=(0, 0, 0), dilate=(1, 1, 1), no_bias=False, name=None,
           **kwargs):
    # mx.sym.Convolution handles 3D kernels directly, so the 2D wrapper can
    # be reused as-is with 3-element tuples.
    return conv2d(data=data, num_filter=num_filter, kernel=kernel,
                  stride=stride, pad=pad, dilate=dilate, no_bias=no_bias,
                  name=name, **kwargs)


def conv3d_bn_act(data, num_filter, height, width, kernel=(1, 1, 1),
                  stride=(1, 1, 1), pad=(0, 0, 0), dilate=(1, 1, 1),
                  no_bias=False, act_type="relu", momentum=0.9,
                  eps=1e-5 + 1e-12, fix_gamma=True, name=None,
                  use_global_stats=False, **kwargs):
    conv = conv3d(data=data, num_filter=num_filter, kernel=kernel,
                  stride=stride, pad=pad, dilate=dilate, no_bias=no_bias,
                  name=name, **kwargs)
    assert name is not None
    global _params
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # Fold the depth axis into height so that 4D BatchNorm can be applied,
    # then restore the 5D (N, C, D, H, W) layout afterwards.
    conv = mx.sym.reshape(conv, shape=(0, 0, -1, width))
    bn = mx.sym.BatchNorm(data=conv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats,
                          name='%s_bn' % name)
    bn = mx.sym.reshape(bn, shape=(0, 0, -1, height, width))
    act = activation(bn, act_type=act_type, name=name)
    return act


def conv3d_act(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1),
               pad=(0, 0, 0), dilate=(1, 1, 1), no_bias=False,
               act_type="relu", name=None, **kwargs):
    conv = conv3d(data=data, num_filter=num_filter, kernel=kernel,
                  stride=stride, pad=pad, dilate=dilate, no_bias=no_bias,
                  name=name, **kwargs)
    act = activation(conv, act_type=act_type, name=name)
    return act


def deconv3d(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1),
             pad=(0, 0, 0), adj=(0, 0, 0), no_bias=True, target_shape=None,
             name=None, **kwargs):
    # Like conv3d, this delegates to the 2D wrapper because
    # mx.sym.Deconvolution accepts 3-element tuples directly.
    return deconv2d(data=data, num_filter=num_filter, kernel=kernel,
                    stride=stride, pad=pad, adj=adj, no_bias=no_bias,
                    target_shape=target_shape, name=name, **kwargs)


def deconv3d_bn_act(data, num_filter, height, width, kernel=(1, 1, 1),
                    stride=(1, 1, 1), pad=(0, 0, 0), adj=(0, 0, 0),
                    no_bias=True, target_shape=None, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True, name=None,
                    use_global_stats=False, **kwargs):
    global _params
    deconv = deconv3d(data=data, num_filter=num_filter, kernel=kernel,
                      stride=stride, pad=pad, adj=adj,
                      target_shape=target_shape, no_bias=no_bias, name=name,
                      **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # Same reshape trick as conv3d_bn_act: fold depth into height for the
    # 4D BatchNorm, then restore the 5D layout.
    deconv = mx.sym.reshape(deconv, shape=(0, 0, -1, width))
    bn = mx.sym.BatchNorm(data=deconv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats,
                          name='%s_bn' % name)
    bn = mx.sym.reshape(bn, shape=(0, 0, -1, height, width))
    act = activation(bn, act_type=act_type, name=name)
    return act


def deconv3d_act(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1),
                 pad=(0, 0, 0), adj=(0, 0, 0), no_bias=True,
                 target_shape=None, act_type="relu", name=None, **kwargs):
    deconv = deconv3d(data=data, num_filter=num_filter, kernel=kernel,
                      stride=stride, pad=pad, adj=adj,
                      target_shape=target_shape, no_bias=no_bias, name=name,
                      **kwargs)
    act = activation(deconv, act_type=act_type, name=name)
    return act
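
# Illustrative only (not part of the original module): `height`/`width`
# passed to the 3D *_bn_act helpers must be the spatial size of the
# convolution *output*, because they are used to restore the 5D
# (N, C, D, H, W) layout after the reshape/BatchNorm/reshape round trip.
def _conv3d_bn_sketch():
    seq = mx.sym.Variable('seq')  # (N, C, D, H, W)
    out = conv3d_bn_act(data=seq, num_filter=4, height=16, width=16,
                        kernel=(3, 3, 3), pad=(1, 1, 1), name='enc3d')
    _, out_shapes, _ = out.infer_shape(seq=(1, 2, 5, 16, 16))
    return out_shapes[0]  # (1, 4, 5, 16, 16)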

def fc_layer(data, num_hidden, no_bias=False, name="fc", **kwargs):
    assert name is not None
    global _params
    weight = _params.get('%s_weight' % name, **kwargs)
    # kwargs are Variable-creation options (e.g. init, lr_mult) and are
    # consumed by _params.get; they must not be forwarded to the operator.
    if not no_bias:
        bias = _params.get('%s_bias' % name, **kwargs)
        fc = mx.sym.FullyConnected(data=data, weight=weight, bias=bias,
                                   num_hidden=num_hidden, no_bias=False,
                                   name=name)
    else:
        fc = mx.sym.FullyConnected(data=data, weight=weight,
                                   num_hidden=num_hidden, no_bias=True,
                                   name=name)
    return fc


def fc_layer_act(data, num_hidden, no_bias=False, act_type="relu", name="fc",
                 **kwargs):
    fc = fc_layer(data=data, num_hidden=num_hidden, no_bias=no_bias,
                  name=name, **kwargs)
    act = activation(data=fc, act_type=act_type, name=name)
    return act


def fc_layer_bn_act(data, num_hidden, no_bias=False, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True, name=None,
                    use_global_stats=False, **kwargs):
    fc = fc_layer(data=data, num_hidden=num_hidden, no_bias=no_bias,
                  name=name, **kwargs)
    assert name is not None
    global _params
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    bn = mx.sym.BatchNorm(data=fc, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          name='%s_bn' % name,
                          use_global_stats=use_global_stats)
    act = activation(bn, act_type=act_type, name=name)
    return act


def downsample_module(data, num_filter, kernel, stride, pad, b_h_w, name,
                      aggre_type=None):
    assert isinstance(data, list)
    # Inputs from multiple timesteps are stacked along the batch axis before
    # the shared strided convolution is applied.
    data = mx.sym.concat(*data, dim=0)
    ret = conv2d_act(data=data, num_filter=num_filter, kernel=kernel,
                     stride=stride, pad=pad,
                     act_type=cfg.MODEL.CNN_ACT_TYPE, name=name + "_conv")
    return ret


def upsample_module(data, num_filter, kernel, stride, pad, b_h_w, name,
                    aggre_type=None):
    assert isinstance(data, list)
    data = mx.sym.concat(*data, dim=0)
    ret = deconv2d_act(data=data, num_filter=num_filter, kernel=kernel,
                       stride=stride, pad=pad,
                       act_type=cfg.MODEL.CNN_ACT_TYPE,
                       name=name + "_deconv")
    return ret
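
# Illustrative smoke test (not part of the original module); assumes only
# that mxnet and nowcasting.config are importable, which the imports above
# already require.
if __name__ == '__main__':
    _x = mx.sym.Variable('x')
    _y = conv2d_bn_act(data=_x, num_filter=16, kernel=(3, 3), stride=(2, 2),
                       pad=(1, 1), name='smoke_conv')
    _, _out_shapes, _ = _y.infer_shape(x=(2, 3, 64, 64))
    # Stride 2 halves the spatial dimensions: expect (2, 16, 32, 32).
    print(_out_shapes[0])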