# NOTE(review): removed web-page scrape artifacts that preceded the source
# ("sqfoo's picture" / "Upload 99 files" / "6021dd1 verified") — they were
# HTML page chrome, not part of the program, and broke Python parsing.
import mxnet as mx
import logging
from nowcasting.ops import *
from nowcasting.operators.common import identity, grid_generator, group_add, constant, save_npy
from nowcasting.operators.conv_rnn import BaseConvRNN
import numpy as np
def flow_conv(data, num_filter, flows, weight, bias, name):
    """Warp `data` along each flow field, concatenate the warped copies
    channel-wise, and project them with a shared 1x1 convolution.

    Parameters
    ----------
    data : mx.sym.Symbol
        Feature map to be warped (NCHW).
    num_filter : int
        Number of output channels of the 1x1 projection.
    flows : list of mx.sym.Symbol
        One 2-channel flow field per trajectory link.
    weight, bias : mx.sym.Symbol
        Shared parameters of the 1x1 convolution.
    name : str
        Name assigned to the convolution node.

    Returns
    -------
    mx.sym.Symbol
        The projected, warp-aggregated feature map.
    """
    assert isinstance(flows, list)
    # Negative flow: GridGenerator expects the displacement that maps each
    # output pixel back to its source location.
    warped = [
        mx.sym.BilinearSampler(
            data=data,
            grid=mx.sym.GridGenerator(data=-flow, transform_type="warp"))
        for flow in flows
    ]
    stacked = mx.sym.concat(*warped, dim=1)
    return mx.sym.Convolution(data=stacked,
                              num_filter=num_filter,
                              kernel=(1, 1),
                              weight=weight,
                              bias=bias,
                              name=name)
class TrajGRU(BaseConvRNN):
    """Trajectory GRU cell (symbolic MXNet implementation).

    A ConvGRU variant in which the state-to-state transition is not a fixed
    convolution: at every step a small subnetwork predicts ``L`` flow fields
    from the current input and previous hidden state, the previous state is
    warped along each flow, and the warped copies are aggregated with a
    shared 1x1 convolution (see ``flow_conv``).
    NOTE(review): this appears to follow the TrajGRU model of Shi et al.,
    "Deep Learning for Precipitation Nowcasting" (NIPS 2017) — confirm.
    """
    def __init__(self, b_h_w, num_filter, zoneout=0.0, L=5,
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                 h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                 act_type="leaky",
                 prefix="TrajGRU", lr_mult=1.0):
        # b_h_w:     (batch, height, width) of the state feature map
        # num_filter: number of channels of the hidden state
        # zoneout:   probability of keeping the previous state per element
        # L:         number of flow fields (trajectory links) per step
        # lr_mult:   learning-rate multiplier applied to all parameters
        super(TrajGRU, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_pad=i2h_pad,
                                      i2h_stride=i2h_stride,
                                      act_type=act_type,
                                      prefix=prefix)
        self._L = L
        self._zoneout = zoneout
        # Parameters of the flow-generating subnetwork: one 5x5 conv on the
        # input, one 5x5 conv on the state (see _flow_generator).
        self.i2f_conv1_weight = self.params.get("i2f_conv1_weight", lr_mult=lr_mult)
        self.i2f_conv1_bias = self.params.get("i2f_conv1_bias", lr_mult=lr_mult)
        self.h2f_conv1_weight = self.params.get("h2f_conv1_weight", lr_mult=lr_mult)
        self.h2f_conv1_bias = self.params.get("h2f_conv1_bias", lr_mult=lr_mult)
        # NOTE(review): f_conv2_* are registered but the layer that used them
        # is commented out in _flow_generator — currently unused parameters.
        self.f_conv2_weight = self.params.get("f_conv2_weight", lr_mult=lr_mult)
        self.f_conv2_bias = self.params.get("f_conv2_bias", lr_mult=lr_mult)
        if cfg.MODEL.TRAJRNN.INIT_GRID:
            # Zero-init the flow output layer so training starts from the
            # identity warp (all flows zero); optionally with a separate
            # learning-rate multiplier for the flow head.
            logging.info("TrajGRU: Initialize Grid Using Zeros!")
            self.f_out_weight = self.params.get("f_out_weight",
                                                lr_mult=lr_mult * cfg.MODEL.TRAJRNN.FLOW_LR_MULT,
                                                init=mx.init.Zero())
            self.f_out_bias = self.params.get("f_out_bias",
                                              lr_mult=lr_mult * cfg.MODEL.TRAJRNN.FLOW_LR_MULT,
                                              init=mx.init.Zero())
        else:
            self.f_out_weight = self.params.get("f_out_weight", lr_mult=lr_mult)
            self.f_out_bias = self.params.get("f_out_bias", lr_mult=lr_mult)
        # GRU gate parameters: input-to-hidden and (warped) hidden-to-hidden,
        # each producing 3 * num_filter channels (reset, update, candidate).
        self.i2h_weight = self.params.get("i2h_weight", lr_mult=lr_mult)
        self.i2h_bias = self.params.get("i2h_bias", lr_mult=lr_mult)
        self.h2h_weight = self.params.get("h2h_weight", lr_mult=lr_mult)
        self.h2h_bias = self.params.get("h2h_bias", lr_mult=lr_mult)
    @property
    def state_postfix(self):
        # Single recurrent state: the hidden map 'h'.
        return ['h']
    @property
    def state_info(self):
        # Shape/layout description of the hidden state for begin_state().
        return [{'shape': (self._batch_size, self._num_filter,
                           self._state_height, self._state_width),
                 '__layout__': "NCHW"}]
    def _flow_generator(self, inputs, states, prefix):
        """Predict ``self._L`` 2-channel flow fields from input and state.

        Returns a list of ``L`` symbols, each a flow field used by
        ``flow_conv`` to warp the previous hidden state. ``inputs`` may be
        None (e.g. in the forecaster), in which case only the state branch
        contributes.
        """
        if inputs is not None:
            i2f_conv1 = mx.sym.Convolution(data=inputs,
                                           weight=self.i2f_conv1_weight,
                                           bias=self.i2f_conv1_bias,
                                           kernel=(5, 5),
                                           dilate=(1, 1),
                                           pad=(2, 2),
                                           num_filter=32,
                                           name="%s_i2f_conv1" % prefix)
        else:
            i2f_conv1 = None
        h2f_conv1 = mx.sym.Convolution(data=states,
                                       weight=self.h2f_conv1_weight,
                                       bias=self.h2f_conv1_bias,
                                       kernel=(5, 5),
                                       dilate=(1, 1),
                                       pad=(2, 2),
                                       num_filter=32,
                                       name="%s_h2f_conv1" % prefix)
        # Fuse the two branches (input branch is optional).
        f_conv1 = i2f_conv1 + h2f_conv1 if i2f_conv1 is not None else h2f_conv1
        f_conv1 = activation(f_conv1, act_type=self._act_type)
        # NOTE(review): disabled second conv layer of the flow subnetwork;
        # its parameters (f_conv2_*) are still registered in __init__.
        # f_conv2 = mx.sym.Convolution(data=f_conv1,
        #                              weight=self.f_conv2_weight,
        #                              bias=self.f_conv2_bias,
        #                              kernel=(5, 5),
        #                              dilate=(1, 1),
        #                              pad=(2, 2),
        #                              num_filter=32,
        #                              name="%s_f_conv2" %prefix)
        # f_conv2 = activation(f_conv2, act_type=self._act_type)
        # Output head: 2 channels (x/y displacement) per trajectory link.
        flows = mx.sym.Convolution(data=f_conv1,
                                   weight=self.f_out_weight,
                                   bias=self.f_out_bias,
                                   kernel=(5, 5),
                                   pad=(2, 2),
                                   num_filter=self._L * 2)
        if cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS:
            # Debug hook: dump the predicted flows to disk as .npy files.
            import os
            flows = save_npy(flows, save_name="%s_flow" %prefix,
                             save_dir=os.path.join(cfg.MODEL.SAVE_DIR, "flows"))
        # Split the (L*2)-channel map into L separate 2-channel flow fields.
        flows = mx.sym.split(flows, num_outputs=self._L, axis=1)
        flows = [flows[i] for i in range(self._L)]
        return flows
    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
        """Run one recurrence step.

        Parameters
        ----------
        inputs : mx.sym.Symbol or None
            Step input; when None only the recurrent path is used.
        states : list or None
            [h] from the previous step; ignored if ``is_initial``.
        is_initial : bool
            If True, start from ``self.begin_state()``.
        ret_mid : bool
            If True, also return an (empty) list of mid results.

        Returns
        -------
        (next_h, [next_h]) or (next_h, [next_h], [])
        """
        self._counter += 1
        name = '%s_t%d' % (self._prefix, self._counter)
        if is_initial:
            states = self.begin_state()[0]
        else:
            states = states[0]
        assert states is not None
        if inputs is not None:
            # Input-to-hidden projection for all three gates at once.
            i2h = mx.sym.Convolution(data=inputs,
                                     weight=self.i2h_weight,
                                     bias=self.i2h_bias,
                                     kernel=self._i2h_kernel,
                                     stride=self._i2h_stride,
                                     dilate=self._i2h_dilate,
                                     pad=self._i2h_pad,
                                     num_filter=self._num_filter * 3,
                                     name="%s_i2h" % name)
            i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1)
        else:
            i2h_slice = None
        prev_h = states
        # Hidden-to-hidden transition: warp prev_h along L predicted flows,
        # then aggregate with a shared 1x1 conv into the three gate inputs.
        flows = self._flow_generator(inputs=inputs, states=states, prefix=name)
        # flows[0] = identity(flows[0], input_debug=True)
        h2h = flow_conv(data=prev_h, num_filter=self._num_filter * 3, flows=flows,
                        weight=self.h2h_weight, bias=self.h2h_bias, name="%s_h2h" % name)
        h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1)
        # Standard GRU gating; the input contribution is skipped when there
        # is no step input.
        if i2h_slice is not None:
            reset_gate = mx.sym.Activation(i2h_slice[0] + h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(i2h_slice[1] + h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(i2h_slice[2] + reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        else:
            reset_gate = mx.sym.Activation(h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        # Convex combination of previous state and candidate memory.
        next_h = update_gate * prev_h + (1 - update_gate) * new_mem
        if self._zoneout > 0.0:
            # Zoneout: per-element, keep prev_h where the dropout mask is 0.
            # NOTE(review): Dropout rescales kept units by 1/(1-p), so the
            # mask values are 0 or 1/(1-p); `where` tests non-zero, so the
            # rescaling should be harmless here — confirm against mx docs.
            mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout)
            next_h = mx.sym.where(mask, next_h, prev_h)
        self._curr_states = [next_h]
        if not ret_mid:
            return next_h, [next_h]
        else:
            return next_h, [next_h], []