|
|
import mxnet as mx |
|
|
import logging |
|
|
from nowcasting.ops import * |
|
|
from nowcasting.operators.common import identity, grid_generator, group_add, constant, save_npy |
|
|
from nowcasting.operators.conv_rnn import BaseConvRNN |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def flow_conv(data, num_filter, flows, weight, bias, name):
    """Warp `data` along every flow field, then mix the warps with a 1x1 conv.

    Each element of `flows` is a displacement field (assumed (N, 2, H, W),
    as required by mx.sym.GridGenerator in "warp" mode -- confirm at call
    sites). `data` is bilinearly resampled once per flow; the resampled
    copies are concatenated on the channel axis and projected to
    `num_filter` channels by a shared 1x1 convolution.

    Parameters
    ----------
    data : mx.sym.Symbol
        Feature map to be warped.
    num_filter : int
        Output channel count of the mixing convolution.
    flows : list of mx.sym.Symbol
        Flow fields; one warped copy of `data` is produced per entry.
    weight, bias : mx.sym.Symbol
        Parameters of the 1x1 mixing convolution.
    name : str
        Symbol name for the mixing convolution.

    Returns
    -------
    mx.sym.Symbol
        The mixed, warped feature map.
    """
    assert isinstance(flows, list)
    # Negating the flow converts "where each pixel moves to" into the
    # sampling-grid convention GridGenerator expects.
    warped = [
        mx.sym.BilinearSampler(
            data=data,
            grid=mx.sym.GridGenerator(data=-flow, transform_type="warp"))
        for flow in flows
    ]
    stacked = mx.sym.concat(*warped, dim=1)
    return mx.sym.Convolution(data=stacked,
                              num_filter=num_filter,
                              kernel=(1, 1),
                              weight=weight,
                              bias=bias,
                              name=name)
|
|
|
|
|
|
|
|
class TrajGRU(BaseConvRNN):
    """Trajectory GRU (TrajGRU) recurrent cell built from MXNet symbols.

    A convolutional GRU variant whose state-to-state transition is
    location-variant: at every step ``_flow_generator`` predicts ``L`` flow
    fields from the current input and the previous hidden state, the
    previous state is warped along each flow (see ``flow_conv``), and a
    shared 1x1 convolution mixes the warped copies into the three gate
    pre-activations.

    NOTE(review): this appears to implement the TrajGRU model of Shi et al.,
    "Deep Learning for Precipitation Nowcasting: A Benchmark and A New
    Model" -- confirm against the paper.
    """

    def __init__(self, b_h_w, num_filter, zoneout=0.0, L=5,
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                 h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                 act_type="leaky",
                 prefix="TrajGRU", lr_mult=1.0):
        """Declare every learnable parameter of the cell.

        Parameters
        ----------
        b_h_w : tuple
            (batch, height, width) of the state feature map; consumed by
            ``BaseConvRNN`` (assumed -- the base class is defined elsewhere).
        num_filter : int
            Channel count of the hidden state.
        zoneout : float
            Probability of an element of the new state being replaced by the
            previous state (zoneout regularization); 0.0 disables it.
        L : int
            Number of flow fields ("links") generated per time step.
        i2h_kernel, i2h_stride, i2h_pad : tuple
            Geometry of the input-to-hidden convolution.
        h2h_kernel, h2h_dilate : tuple
            Forwarded to ``BaseConvRNN`` only; the state-to-state path in
            this cell is flow warping plus a 1x1 convolution, so these do
            not shape the h2h conv performed here.
        act_type : str
            Activation type for the candidate state (see ``activation``
            from nowcasting.ops).
        prefix : str
            Name prefix for all symbols and parameters.
        lr_mult : float
            Base learning-rate multiplier applied to every parameter.
        """
        super(TrajGRU, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_pad=i2h_pad,
                                      i2h_stride=i2h_stride,
                                      act_type=act_type,
                                      prefix=prefix)
        self._L = L
        self._zoneout = zoneout
        # 5x5 convolutions mapping the input / previous state to the
        # 32-channel intermediate of the flow generator.
        self.i2f_conv1_weight = self.params.get("i2f_conv1_weight", lr_mult=lr_mult)
        self.i2f_conv1_bias = self.params.get("i2f_conv1_bias", lr_mult=lr_mult)
        self.h2f_conv1_weight = self.params.get("h2f_conv1_weight", lr_mult=lr_mult)
        self.h2f_conv1_bias = self.params.get("h2f_conv1_bias", lr_mult=lr_mult)
        # NOTE(review): declared but never referenced in this file --
        # possibly a leftover second flow-generator conv layer.
        self.f_conv2_weight = self.params.get("f_conv2_weight", lr_mult=lr_mult)
        self.f_conv2_bias = self.params.get("f_conv2_bias", lr_mult=lr_mult)
        # Final conv producing the L*2-channel flow maps. Optionally
        # zero-initialized (so training starts from a zero / identity warp)
        # with a config-controlled learning-rate multiplier.
        if cfg.MODEL.TRAJRNN.INIT_GRID:
            logging.info("TrajGRU: Initialize Grid Using Zeros!")
            self.f_out_weight = self.params.get("f_out_weight",
                                                lr_mult=lr_mult * cfg.MODEL.TRAJRNN.FLOW_LR_MULT,
                                                init=mx.init.Zero())
            self.f_out_bias = self.params.get("f_out_bias",
                                              lr_mult=lr_mult * cfg.MODEL.TRAJRNN.FLOW_LR_MULT,
                                              init=mx.init.Zero())
        else:
            self.f_out_weight = self.params.get("f_out_weight", lr_mult=lr_mult)
            self.f_out_bias = self.params.get("f_out_bias", lr_mult=lr_mult)
        # Gate convolutions: input-to-hidden, and the 1x1 post-warp
        # hidden-to-hidden used by flow_conv; each yields 3*num_filter
        # channels (reset / update / candidate).
        self.i2h_weight = self.params.get("i2h_weight", lr_mult=lr_mult)
        self.i2h_bias = self.params.get("i2h_bias", lr_mult=lr_mult)
        self.h2h_weight = self.params.get("h2h_weight", lr_mult=lr_mult)
        self.h2h_bias = self.params.get("h2h_bias", lr_mult=lr_mult)

    @property
    def state_postfix(self):
        # Single recurrent state: the hidden map 'h'.
        return ['h']

    @property
    def state_info(self):
        # One NCHW state of shape (batch, num_filter, H, W).
        return [{'shape': (self._batch_size, self._num_filter,
                           self._state_height, self._state_width),
                 '__layout__': "NCHW"}]

    def _flow_generator(self, inputs, states, prefix):
        """Predict the L flow fields for the current step.

        Returns a list of L symbols, each a 2-channel slice (axis 1) of the
        L*2-channel output of the final 5x5 convolution. ``inputs`` may be
        None, in which case the flows are computed from the previous state
        alone.
        """
        if inputs is not None:
            i2f_conv1 = mx.sym.Convolution(data=inputs,
                                           weight=self.i2f_conv1_weight,
                                           bias=self.i2f_conv1_bias,
                                           kernel=(5, 5),
                                           dilate=(1, 1),
                                           pad=(2, 2),
                                           num_filter=32,
                                           name="%s_i2f_conv1" % prefix)
        else:
            i2f_conv1 = None
        h2f_conv1 = mx.sym.Convolution(data=states,
                                       weight=self.h2f_conv1_weight,
                                       bias=self.h2f_conv1_bias,
                                       kernel=(5, 5),
                                       dilate=(1, 1),
                                       pad=(2, 2),
                                       num_filter=32,
                                       name="%s_h2f_conv1" % prefix)
        # Fuse input and state contributions (state-only when inputs is None).
        f_conv1 = i2f_conv1 + h2f_conv1 if i2f_conv1 is not None else h2f_conv1
        f_conv1 = activation(f_conv1, act_type=self._act_type)

        flows = mx.sym.Convolution(data=f_conv1,
                                   weight=self.f_out_weight,
                                   bias=self.f_out_bias,
                                   kernel=(5, 5),
                                   pad=(2, 2),
                                   num_filter=self._L * 2)
        if cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS:
            import os
            # Debug hook: dump the raw flow maps to disk as .npy files.
            flows = save_npy(flows, save_name="%s_flow" %prefix,
                             save_dir=os.path.join(cfg.MODEL.SAVE_DIR, "flows"))
        # Split the L*2 channels into L separate 2-channel flow fields.
        flows = mx.sym.split(flows, num_outputs=self._L, axis=1)
        flows = [flows[i] for i in range(self._L)]
        return flows

    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
        """Run one TrajGRU step.

        Parameters
        ----------
        inputs : mx.sym.Symbol or None
            Current input feature map; None means state-only update
            (e.g. the forecasting phase).
        states : list or None
            Previous state as [h]; ignored when ``is_initial`` is True.
        is_initial : bool
            If True, start from the zero state returned by ``begin_state``.
        ret_mid : bool
            If True, additionally return a (here always empty) list of
            mid-results, for interface parity with other cells.

        Returns
        -------
        (next_h, [next_h]) or (next_h, [next_h], [])
        """
        self._counter += 1
        name = '%s_t%d' % (self._prefix, self._counter)
        # Resolve the previous hidden state: fresh zero state on the first
        # call, otherwise the single 'h' entry of `states`.
        if is_initial:
            states = self.begin_state()[0]
        else:
            states = states[0]
        assert states is not None
        if inputs is not None:
            # One conv produces all three gate pre-activations
            # (reset, update, candidate) stacked on the channel axis.
            i2h = mx.sym.Convolution(data=inputs,
                                     weight=self.i2h_weight,
                                     bias=self.i2h_bias,
                                     kernel=self._i2h_kernel,
                                     stride=self._i2h_stride,
                                     dilate=self._i2h_dilate,
                                     pad=self._i2h_pad,
                                     num_filter=self._num_filter * 3,
                                     name="%s_i2h" % name)
            i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1)
        else:
            i2h_slice = None
        prev_h = states
        flows = self._flow_generator(inputs=inputs, states=states, prefix=name)

        # State-to-state path: warp prev_h along each predicted flow, then
        # mix the warped copies with a 1x1 conv into 3*num_filter channels.
        h2h = flow_conv(data=prev_h, num_filter=self._num_filter * 3, flows=flows,
                        weight=self.h2h_weight, bias=self.h2h_bias, name="%s_h2h" % name)
        h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1)
        if i2h_slice is not None:
            reset_gate = mx.sym.Activation(i2h_slice[0] + h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(i2h_slice[1] + h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(i2h_slice[2] + reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        else:
            # No input: gates depend on the warped previous state only.
            reset_gate = mx.sym.Activation(h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        # Standard GRU blend of the previous state and the candidate.
        next_h = update_gate * prev_h + (1 - update_gate) * new_mem
        if self._zoneout > 0.0:
            # Zoneout: Dropout on a ones tensor zeroes entries with
            # probability `zoneout`; where() then keeps prev_h exactly at
            # the zeroed positions. NOTE(review): surviving mask entries are
            # scaled by 1/(1-p) by Dropout, which only matters for the
            # nonzero-ness test performed by where().
            mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout)
            next_h = mx.sym.where(mask, next_h, prev_h)
        self._curr_states = [next_h]
        if not ret_mid:
            return next_h, [next_h]
        else:
            # Mid-result slot kept for interface parity; always empty here.
            return next_h, [next_h], []
|
|
|