# NOTE: removed non-Python web-page artifacts that preceded this file
# ("sqfoo's picture / Upload 99 files / 6021dd1 verified") — scrape residue,
# not part of the source.
import mxnet as mx
from mxnet.rnn import BaseRNNCell
from nowcasting.ops import activation
from nowcasting.operators.common import group_add
class MyBaseRNNCell(BaseRNNCell):
    """Project-specific extension of ``mxnet.rnn.BaseRNNCell``.

    Adds to the upstream base class:

    * a step ``__call__`` contract with ``is_initial`` / ``ret_mid`` flags,
    * tracking of the most recently produced states (``get_current_states``),
    * an ``unroll`` that understands the time-major 'TC' / 'TNC' layouts
      used in this project and can return per-step mid-level info.
    """
    def __init__(self, prefix="MyBaseRNNCell", params=None):
        super(MyBaseRNNCell, self).__init__(prefix=prefix, params=params)

    def __call__(self, inputs, states, is_initial=False, ret_mid=False):
        # One RNN step.  Subclasses must implement this and return
        # (output, states) or, when ret_mid is True, (output, states, mid_info).
        raise NotImplementedError()

    def reset(self):
        # Also clear the cached states from any previous unroll.
        super(MyBaseRNNCell, self).reset()
        self._curr_states = None

    def get_current_states(self):
        # States produced by the most recent step, or None right after reset().
        return self._curr_states

    def unroll(self, length, inputs=None, begin_state=None, ret_mid=False,
               input_prefix='', layout='TC', merge_outputs=False):
        """Unroll an RNN cell across time steps.

        Parameters
        ----------
        length : int
            number of steps to unroll
        inputs : Symbol, list of Symbol, or None
            if inputs is a single Symbol (usually the output
            of an Embedding symbol), it should be time-major:
            (length, ...) if layout == 'TC',
            or (length, batch_size, ...) if layout == 'TNC'.
            If inputs is a list of symbols (usually output of
            a previous unroll), it must contain `length` per-step
            symbols of shape (batch_size, ...).
            If inputs is None, each step receives None as input.
        begin_state : nested list of Symbol
            input states. Created by begin_state()
            or output state of another cell. Created
            from begin_state() if None.
        ret_mid : bool
            if True, additionally return the mid-level info collected
            from every step.
        input_prefix : str
            prefix for automatically created input
            placeholders.
        layout : str
            layout of the input symbol; only 'TC' and 'TNC' are supported
            (time axis first in both cases).  Only used if inputs
            is a single Symbol.
        merge_outputs : bool
            if False, return outputs as a list of per-step Symbols.
            If True, stack outputs across time steps (new leading
            time axis) and return a single symbol.

        Returns
        -------
        outputs : list of Symbol or Symbol
            output symbols (a single stacked Symbol if merge_outputs).
        states : Symbol or nested list of Symbol
            has the same structure as begin_state()
        mid_info : list of Symbol
            only returned when ret_mid is True.
        """
        # Start from a clean counter/state so generated symbol names are
        # deterministic for this unroll.
        self.reset()
        assert layout == 'TNC' or layout == 'TC'
        if inputs is not None:
            if isinstance(inputs, mx.sym.Symbol):
                assert len(inputs.list_outputs()) == 1, \
                    "unroll doesn't allow grouped symbol as input. Please " \
                    "convert to list first or let unroll handle slicing"
                # Both supported layouts are time-major, so slicing is always
                # along axis 0; for 'TNC' the sliced time axis is squeezed out.
                if 'N' in layout:
                    inputs = mx.sym.SliceChannel(inputs, axis=0, num_outputs=length,
                                                 squeeze_axis=1)
                else:
                    inputs = mx.sym.SliceChannel(inputs, axis=0, num_outputs=length)
            else:
                assert len(inputs) == length
        else:
            # No inputs: each step's input is None (pure state transition).
            inputs = [None] * length
        if begin_state is None:
            states = self.begin_state()
        else:
            states = begin_state
        outputs = []
        mid_infos = []
        for i in range(length):
            # is_initial is only True on the first step when no explicit
            # begin_state was given, letting the cell build its own init state.
            output, states, mid_info = self(inputs=inputs[i], states=states,
                                            is_initial=(i == 0 and (begin_state is None)),
                                            ret_mid=True)
            outputs.append(output)
            mid_infos.extend(mid_info)
        if merge_outputs:
            # Stack the per-step outputs along a new leading time axis.
            outputs = [mx.sym.expand_dims(i, axis=0) for i in outputs]
            outputs = mx.sym.Concat(*outputs, dim=0)
        if ret_mid:
            return outputs, states, mid_infos
        else:
            return outputs, states
class BaseStackRNN(object):
    """A stack of ``stack_num`` RNN cells of the same class.

    Layer ``i`` feeds its outputs to layer ``i + 1`` (optionally with
    residual connections).  Layer states exist in two interchangeable
    representations:

    * "split" — a list (one entry per layer) of per-layer state lists;
    * "concat" — a single state list where each state holds all layers
      concatenated along the channel axis (the 'C' position of the
      state's ``__layout__``).
    """
    def __init__(self, base_rnn_class, stack_num=1,
                 name="BaseStackRNN", residual_connection=True,
                 **kwargs):
        # base_rnn_class: subclass of MyBaseRNNCell used for every layer.
        # residual_connection: add each layer's input to its output
        #   (skipped for the first layer, which has no matching input).
        # kwargs: forwarded verbatim to every layer's constructor.
        self._base_rnn_class = base_rnn_class
        self._residual_connection = residual_connection
        self._name = name
        self._stack_num = stack_num
        self._prefix = name + "_"
        self._rnns = [base_rnn_class(prefix=self._name + "_%d" %i, **kwargs) for i in range(stack_num)]
        self._init_counter = 0
        # Cache for the lazily computed concat-form state info.
        self._state_info = None

    def init_state_vars(self):
        """Create named placeholder variables for the initial states.

        Returns
        -------
        state_vars : list of Symbol
            one variable per concat-form state, shaped per `state_info`,
            usable as starting states for the first RNN step.
        """
        state_vars = []
        for i, info in enumerate(self.state_info):
            state = mx.sym.var(name='%s_begin_state_%s' % (self._name, self.state_postfix[i]), **info)
            state_vars.append(state)
        return state_vars

    def concat_to_split(self, concat_states):
        # Split each concat-form state along its channel axis into
        # `stack_num` equal parts, regrouping them per layer.
        assert len(concat_states) == len(self.state_info)
        split_states = [[] for i in range(self._stack_num)]
        for i in range(len(self.state_info)):
            # Channel axis position comes from the state's layout string.
            channel_axis = self.state_info[i]['__layout__'].lower().find('c')
            ele = mx.sym.split(concat_states[i], num_outputs=self._stack_num, axis=channel_axis)
            for j in range(self._stack_num):
                split_states[j].append(ele[j])
        return split_states

    def split_to_concat(self, split_states):
        # Concat the states together
        concat_states = []
        for i in range(len(self.state_info)):
            channel_axis = self.state_info[i]['__layout__'].lower().find('c')
            concat_states.append(mx.sym.concat(*[ele[i] for ele in split_states],
                                               dim=channel_axis))
        return concat_states

    def check_concat(self, states):
        # Concat form is a flat list of Symbols; split form is a list of
        # per-layer lists.  Inspecting the first element distinguishes them.
        ret = not isinstance(states[0], list)
        return ret

    def to_concat(self, states):
        # Idempotent conversion to concat form.
        if not self.check_concat(states):
            states = self.split_to_concat(states)
        return states

    def to_split(self, states):
        # Idempotent conversion to split (per-layer) form.
        if self.check_concat(states):
            states = self.concat_to_split(states)
        return states

    @property
    def state_postfix(self):
        # All layers share the same state naming scheme; use layer 0's.
        return self._rnns[0].state_postfix

    @property
    def state_info(self):
        """Concat-form state info.

        For each state index, the per-layer shapes are merged by summing
        the channel dimension across layers; layouts must agree between
        layers.  Computed once and cached.
        """
        if self._state_info is None:
            info = []
            for i in range(len(self._rnns[0].state_info)):
                ele = {}
                for rnn in self._rnns:
                    if 'shape' not in ele:
                        # First layer seeds the shape (as a mutable list).
                        ele['shape'] = list(rnn.state_info[i]['shape'])
                    else:
                        # Subsequent layers add their channel extent.
                        channel_dim = rnn.state_info[i]['__layout__'].lower().find('c')
                        ele['shape'][channel_dim] += rnn.state_info[i]['shape'][channel_dim]
                    if '__layout__' not in ele:
                        ele['__layout__'] = rnn.state_info[i]['__layout__'].upper()
                    else:
                        assert rnn.state_info[i]['__layout__'] == ele['__layout__'].upper()
                ele['shape'] = tuple(ele['shape'])
                info.append(ele)
            self._state_info = info
            return info
        else:
            return self._state_info

    def flatten_add_layout(self, states, blocked=False):
        """Convert states to concat form and attach their layout attribute.

        Parameters
        ----------
        states : list of list or list
            split- or concat-form states.
        blocked : bool
            if True, wrap each state in BlockGrad (stops gradient flow)
            while attaching the layout.

        Returns
        -------
        ret : list
            concat-form states, each carrying its '__layout__' attribute.
        """
        states = self.to_concat(states)
        assert self.check_concat(states)
        ret = []
        for i, ele in enumerate(states):
            if blocked:
                ret.append(mx.sym.BlockGrad(ele, __layout__=self.state_info[i]['__layout__']))
            else:
                ele._set_attr(__layout__=self.state_info[i]['__layout__'])
                ret.append(ele)
        return ret

    def reset(self):
        # Reset every layer's step counter / cached states.
        for i in range(len(self._rnns)):
            self._rnns[i].reset()

    def unroll(self, length, inputs=None, begin_states=None, ret_mid=False):
        """Unroll every layer for `length` steps, feeding each layer's
        outputs (plus optional residual) into the next layer.

        Returns per-layer outputs, per-layer final states, and (when
        ret_mid) per-layer mid info.
        """
        if begin_states is None:
            begin_states = self.init_state_vars()
        # Work in split (per-layer) form internally.
        begin_states = self.to_split(begin_states)
        assert len(begin_states) == self._stack_num, len(begin_states)
        for ele in begin_states:
            assert len(ele) == len(self.state_info)
        outputs = []
        final_states = []
        mid_infos = []
        for i in range(len(self._rnns)):
            rnn_out_list, rnn_final_states, rnn_mid_infos =\
                self._rnns[i].unroll(length=length, inputs=inputs,
                                     begin_state=begin_states[i],
                                     layout="TC",
                                     ret_mid=True)
            if self._residual_connection and i > 0:
                # Use residual connections
                rnn_out_list = group_add(lhs=rnn_out_list, rhs=inputs)
            # This layer's outputs become the next layer's inputs.
            inputs = rnn_out_list
            outputs.append(rnn_out_list)
            final_states.append(rnn_final_states)
            mid_infos.append(rnn_mid_infos)
        if ret_mid:
            return outputs, final_states, mid_infos
        else:
            return outputs, final_states
class MyGRU(MyBaseRNNCell):
    """GRU cell with optional zoneout regularization.

    Parameters
    ----------
    num_hidden : int
        number of units in output symbol
    zoneout : float
        probability of keeping the previous hidden state per unit
        (0.0 disables zoneout)
    act_type : str
        activation for the candidate memory (default "tanh")
    prefix : str, default 'gru_'
        prefix for name of layers
        (and name of weight if params is None)
    params : RNNParams or None
        container for weight sharing between cells.
        created if None.
    """
    def __init__(self, num_hidden, zoneout=0.0, act_type="tanh", prefix='gru_', params=None):
        super(MyGRU, self).__init__(prefix=prefix, params=params)
        self._num_hidden = num_hidden
        self._act_type = act_type
        self._zoneout = zoneout
        # Weights come from the shared RNNParams container so cells built
        # with the same `params` share parameters.
        self._i2h_weight = self.params.get('i2h_weight')
        self._i2h_bias = self.params.get('i2h_bias')
        self._h2h_weight = self.params.get('h2h_weight')
        self._h2h_bias = self.params.get('h2h_bias')

    @property
    def state_info(self):
        """shape(s) of states"""
        # Leading 0 lets MXNet infer the batch dimension; layout is
        # (batch, channel).
        return [{'shape': (0, self._num_hidden), '__layout__': "NC"}]

    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
        """One GRU step.

        Parameters
        ----------
        inputs : Symbol or None
            step input; when None only the hidden-to-hidden path is used.
        states : list of Symbol or None
            [prev_h]; required unless is_initial is True.
        is_initial : bool
            if True, start from this cell's begin_state() instead of `states`.
        ret_mid : bool
            if True, also return an (empty) mid-info list.

        Returns
        -------
        next_h, [next_h] — and additionally [] when ret_mid is True.
        """
        # Per-step unique name so symbols from different steps don't collide.
        name = '%s_t%d' % (self._prefix, self._counter)
        self._counter += 1
        if is_initial:
            prev_h = self.begin_state()[0]
        else:
            # BUG FIX: the guard must run before indexing; previously
            # `states[0]` preceded the assert, so states=None raised a
            # TypeError instead of a clear assertion failure.
            assert states is not None
            prev_h = states[0]
        if inputs is not None:
            # Flatten trailing dims so FullyConnected sees (batch, features).
            inputs = mx.sym.reshape(inputs, shape=(0, -1))
            # Fused projection for the 3 gates: reset, update, candidate.
            i2h = mx.sym.FullyConnected(data=inputs,
                                        num_hidden=self._num_hidden * 3,
                                        weight=self._i2h_weight,
                                        bias=self._i2h_bias,
                                        name="%s_i2h" %name)
            i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1)
        else:
            i2h_slice = None
        h2h = mx.sym.FullyConnected(data=prev_h,
                                    num_hidden=self._num_hidden * 3,
                                    weight=self._h2h_weight,
                                    bias=self._h2h_bias,
                                    name="%s_h2h" %name)
        h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1)
        if i2h_slice is not None:
            reset_gate = activation(i2h_slice[0] + h2h_slice[0], act_type="sigmoid",
                                    name=name + "_r")
            update_gate = activation(i2h_slice[1] + h2h_slice[1], act_type="sigmoid",
                                     name=name + "_u")
            new_mem = activation(i2h_slice[2] + reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        else:
            # No input: gates and candidate depend on the hidden state only.
            reset_gate = activation(h2h_slice[0], act_type="sigmoid",
                                    name=name + "_r")
            update_gate = activation(h2h_slice[1], act_type="sigmoid",
                                     name=name + "_u")
            new_mem = activation(reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        next_h = update_gate * prev_h + (1 - update_gate) * new_mem
        if self._zoneout > 0.0:
            # Zoneout: per-unit, keep prev_h where the dropout mask is zero.
            # (Dropout on a ones tensor yields a random keep/drop mask that is
            # only active in training mode.)
            mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout)
            next_h = mx.sym.where(mask, next_h, prev_h)
        self._curr_states = [next_h]
        if not ret_mid:
            return next_h, [next_h]
        else:
            return next_h, [next_h], []
if __name__ == '__main__':
    # Smoke test: stack five ConvGRU layers and unroll them for 8 steps.
    from nowcasting.operators.conv_rnn import ConvGRU
    stack_rnn = BaseStackRNN(base_rnn_class=ConvGRU, stack_num=5,
                             b_h_w=(4, 32, 32), num_filter=32)
    print(stack_rnn.state_info)
    seq_sym = mx.sym.var(name="inputs", shape=(8, 4, 16, 32, 32))
    step_outs, last_states, step_mids = stack_rnn.unroll(length=8,
                                                         inputs=seq_sym,
                                                         ret_mid=True)
    print(len(step_outs), len(step_outs[0]))
    print(len(last_states), len(last_states[0]))