import mxnet as mx
import numpy as np

from nowcasting.ops import *
from nowcasting.operators.common import identity, grid_generator, group_add
from nowcasting.operators.base_rnn import MyBaseRNNCell

|
class BaseConvRNN(MyBaseRNNCell):
    def __init__(self, num_filter, b_h_w,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 act_type="tanh", prefix="ConvRNN", params=None):
        super(BaseConvRNN, self).__init__(prefix=prefix + "_", params=params)
        self._num_filter = num_filter
        self._h2h_kernel = h2h_kernel
        assert (self._h2h_kernel[0] % 2 == 1) and (self._h2h_kernel[1] % 2 == 1), \
            "Only odd h2h kernel sizes are supported, got h2h_kernel=%s" % str(h2h_kernel)
        self._h2h_pad = (h2h_dilate[0] * (h2h_kernel[0] - 1) // 2,
                         h2h_dilate[1] * (h2h_kernel[1] - 1) // 2)
        self._h2h_dilate = h2h_dilate
        self._i2h_kernel = i2h_kernel
        self._i2h_stride = i2h_stride
        self._i2h_pad = i2h_pad
        self._i2h_dilate = i2h_dilate
        self._act_type = act_type
        assert len(b_h_w) == 3
        i2h_dilate_ksize_h = 1 + (self._i2h_kernel[0] - 1) * self._i2h_dilate[0]
        i2h_dilate_ksize_w = 1 + (self._i2h_kernel[1] - 1) * self._i2h_dilate[1]
        self._batch_size, self._height, self._width = b_h_w
        self._state_height = (self._height + 2 * self._i2h_pad[0] - i2h_dilate_ksize_h) \
            // self._i2h_stride[0] + 1
        self._state_width = (self._width + 2 * self._i2h_pad[1] - i2h_dilate_ksize_w) \
            // self._i2h_stride[1] + 1
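        # Standard convolution output-size arithmetic. For example, with
        # height=40, kernel=3, pad=1, stride=1, dilate=1 the effective kernel
        # size is 1 + (3 - 1) * 1 = 3, so state_height = (40 + 2 - 3) // 1 + 1
        # = 40 and the spatial shape of the state matches the input.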
|
|
        print(self._prefix, self._state_height, self._state_width)
        self._curr_states = None
        self._counter = 0

|
class ConvRNN(BaseConvRNN):
    def __init__(self, num_filter, b_h_w,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 act_type="leaky",
                 layer_norm=False,
                 prefix="ConvRNN",
                 params=None):
|
|
        super(ConvRNN, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_stride=i2h_stride,
                                      i2h_pad=i2h_pad,
                                      i2h_dilate=i2h_dilate,
                                      act_type=act_type,
                                      prefix=prefix,
                                      params=params)
        self._layer_norm = layer_norm
        self.i2h_weight = self.params.get('i2h_weight')
        self.i2h_bias = self.params.get('i2h_bias')
        self.h2h_weight = self.params.get('h2h_weight')
        self.h2h_bias = self.params.get('h2h_bias', init=mx.init.Normal())
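        # Note: h2h_bias is given an explicit Normal initializer above; the
        # remaining parameters fall back to the cell's default initialization.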
|
|
|
|
|
    @property
    def state_info(self):
        return [{'shape': (self._batch_size, self._num_filter,
                           self._state_height, self._state_width),
                 '__layout__': "NCHW"}]
|
|
|
|
|
    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
        name = '%s_t%d' % (self._prefix, self._counter)
        self._counter += 1
        states = self.begin_state()[0] if is_initial else states[0]
        assert states is not None
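        # At this point `states` holds the previous hidden state, or the
        # initial state from begin_state() (zero symbols by default in MXNet
        # RNN cells) at the first timestep.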
|
|
        if inputs is not None:
            i2h = mx.sym.Convolution(data=inputs,
                                     weight=self.i2h_weight,
                                     bias=self.i2h_bias,
                                     kernel=self._i2h_kernel,
                                     stride=self._i2h_stride,
                                     dilate=self._i2h_dilate,
                                     pad=self._i2h_pad,
                                     num_filter=self._num_filter,
                                     name="%s_i2h" % name)
        else:
            i2h = None
|
|
        h2h = mx.sym.Convolution(data=states,
                                 weight=self.h2h_weight,
                                 bias=self.h2h_bias,
                                 kernel=self._h2h_kernel,
                                 stride=(1, 1),
                                 dilate=self._h2h_dilate,
                                 pad=self._h2h_pad,
                                 num_filter=self._num_filter,
                                 name="%s_h2h" % name)
|
|
        pre_act = i2h + h2h if i2h is not None else h2h
        if self._layer_norm:
            pre_act = layer_normalization(pre_act,
                                          num_filters=self._num_filter,
                                          name=self._prefix + "ln")
        next_h = activation(pre_act, act_type=self._act_type, name=name + "_state")
        self._curr_states = [next_h]
        if not ret_mid:
            return next_h, [next_h]
        else:
            return next_h, [next_h], [i2h, h2h]

|
class ConvGRU(BaseConvRNN):
    def __init__(self, num_filter, b_h_w, zoneout=0.0,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_adj=(0, 0), no_i2h_bias=False, use_deconv=False,
                 act_type="leaky", prefix="ConvGRU", lr_mult=1.0):
|
|
"""Initializing a ConvGRU/DeconvGRU |
|
|
|
|
|
r_t = \sigma(W_r \ast x_t + R_r \ast h_{t-1} + b_{W_r} + b_{R_r}) |
|
|
u_t = \sigma(W_u \ast x_t + R_u \ast h_{t-1} + b_{W_u} + b_{R_u}) |
|
|
h^\prime_t = tanh(W_h \ast x_t + r_t \circ (R_h \ast h_{t-1} + b_{R_h}) + b_{W_h}) |
|
|
h_t = (1 - u_t) \circ h^\prime_t + u_t \circ h_{t-1} |
|
|
|
|
|
Parameters: (reset_gate, update_gate, new_mem) |
|
|
W_{i2h} = [W_r, W_u, W_h] |
|
|
b_{i2h} = [b_{W_r}, b_{W_u}, b_{W_h}] |
|
|
W_{h2h} = [R_r, R_u, R_h] |
|
|
b_{h2h} = [b_{R_r}, b_{R_u}, b_{R_h}] |
|
|
|
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
num_hidden : int |
|
|
hidden_act_type : str |
|
|
name : str |
|
|
""" |
|
|
        super(ConvGRU, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_pad=i2h_pad,
                                      i2h_stride=i2h_stride,
                                      i2h_dilate=i2h_dilate,
                                      act_type=act_type,
                                      prefix=prefix)
        self._no_i2h_bias = no_i2h_bias
        self._i2h_adj = i2h_adj
        self._use_deconv = use_deconv
        if self._no_i2h_bias:
            assert use_deconv, "no_i2h_bias is only supported together with use_deconv"
        self._zoneout = zoneout
        self.i2h_weight = self.params.get("i2h_weight", lr_mult=lr_mult)
        self.i2h_bias = self.params.get("i2h_bias", lr_mult=lr_mult)
        self.h2h_weight = self.params.get("h2h_weight", lr_mult=lr_mult)
        self.h2h_bias = self.params.get("h2h_bias", lr_mult=lr_mult)
|
|
|
|
|
    @property
    def state_postfix(self):
        return ['h']

    @property
    def state_info(self):
        return [{'shape': (self._batch_size, self._num_filter,
                           self._state_height, self._state_width),
                 '__layout__': "NCHW"}]
|
|
|
|
|
    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
        name = '%s_t%d' % (self._prefix, self._counter)
        self._counter += 1
        if is_initial:
            states = self.begin_state()[0]
        else:
            states = states[0]
        assert states is not None
|
|
        if inputs is not None:
            if self._use_deconv:
                i2h_kwargs = dict(data=inputs,
                                  weight=self.i2h_weight,
                                  kernel=self._i2h_kernel,
                                  stride=self._i2h_stride,
                                  pad=self._i2h_pad,
                                  adj=self._i2h_adj,
                                  num_filter=self._num_filter * 3,
                                  name="%s_i2h" % name)
                if self._no_i2h_bias:
                    i2h = mx.sym.Deconvolution(no_bias=True, **i2h_kwargs)
                else:
                    i2h = mx.sym.Deconvolution(bias=self.i2h_bias, **i2h_kwargs)
            else:
                i2h = mx.sym.Convolution(data=inputs,
                                         weight=self.i2h_weight,
                                         bias=self.i2h_bias,
                                         kernel=self._i2h_kernel,
                                         stride=self._i2h_stride,
                                         dilate=self._i2h_dilate,
                                         pad=self._i2h_pad,
                                         num_filter=self._num_filter * 3,
                                         name="%s_i2h" % name)
|
|
            i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1)
        else:
            i2h_slice = None
|
|
        prev_h = states
|
|
        h2h = mx.sym.Convolution(data=prev_h,
                                 weight=self.h2h_weight,
                                 bias=self.h2h_bias,
                                 kernel=self._h2h_kernel,
                                 stride=(1, 1),
                                 dilate=self._h2h_dilate,
                                 pad=self._h2h_pad,
                                 num_filter=self._num_filter * 3,
                                 name="%s_h2h" % name)
        h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1)
|
|
        if i2h_slice is not None:
            reset_gate = mx.sym.Activation(i2h_slice[0] + h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(i2h_slice[1] + h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(i2h_slice[2] + reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        else:
            reset_gate = mx.sym.Activation(h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        next_h = update_gate * prev_h + (1 - update_gate) * new_mem
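        # Zoneout: keep the previous state with probability self._zoneout.
        # Dropout applied to a tensor of ones yields a mask whose entries are
        # either 0 or 1 / (1 - p); mx.sym.where treats any non-zero entry as
        # "take next_h", so prev_h is retained wherever the mask is zero.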
|
|
        if self._zoneout > 0.0:
            mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout)
            next_h = mx.sym.where(mask, next_h, prev_h)
        self._curr_states = [next_h]
        if not ret_mid:
            return next_h, [next_h]
        else:
            return next_h, [next_h], []

|
if __name__ == '__main__':
    data = mx.sym.Variable('data')
    data = mx.sym.SliceChannel(data, axis=0, num_outputs=11, squeeze_axis=True)
    conv_gru1 = ConvGRU(num_filter=100, b_h_w=(3, 40, 40),
                        prefix="conv_gru1")
    out, states = conv_gru1(inputs=data[0], is_initial=True)
    for i in range(1, 11):
        out, states = conv_gru1(inputs=data[i], states=states)
    conv_gru_forward_backward_time = \
        mx.test_utils.check_speed(out,
                                  location={'data': np.random.normal(size=(11, 3, 128, 40, 40))},
                                  N=2)
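    # check_speed runs N forward/backward passes of the symbol and returns the
    # average wall-clock time per pass.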
|
|
    net = mx.mod.Module(out, data_names=['data'], label_names=None, context=mx.gpu())
    net.bind(data_shapes=[('data', (11, 3, 128, 40, 40))],
             grad_req='add')
    net.init_params()
    net.forward(mx.io.DataBatch(data=[mx.random.normal(shape=(11, 3, 128, 40, 40))], label=None),
                is_train=False)
    print(net.get_outputs()[0].asnumpy())

    data = mx.sym.Variable('data')
    data = mx.sym.SliceChannel(data, axis=0, num_outputs=11, squeeze_axis=True)
    conv_rnn1 = ConvRNN(num_filter=100, b_h_w=(3, 40, 40),
                        prefix="conv_rnn1")
    out, states = conv_rnn1(inputs=data[0], is_initial=True)
    for i in range(1, 11):
        out, states = conv_rnn1(inputs=data[i], states=states)
    conv_rnn_forward_backward_time = \
        mx.test_utils.check_speed(out,
                                  location={'data': np.random.normal(size=(11, 3, 128, 40, 40))},
                                  N=2)
    net = mx.mod.Module(out, data_names=['data'], label_names=None, context=mx.gpu())
    net.bind(data_shapes=[('data', (11, 3, 128, 40, 40))],
             grad_req='add')
    net.init_params()
    net.forward(mx.io.DataBatch(data=[mx.random.normal(shape=(11, 3, 128, 40, 40))], label=None),
                is_train=False)
    print(net.get_outputs()[0].asnumpy())

print("ConvGRU Time:", conv_gru_forward_backward_time) |
|
|
print("ConvRNN Time:", conv_rnn_forward_backward_time) |
|
|
|