File size: 8,764 Bytes
6021dd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import mxnet as mx
import logging
from nowcasting.ops import *
from nowcasting.operators.common import identity, grid_generator, group_add, constant, save_npy
from nowcasting.operators.conv_rnn import BaseConvRNN
import numpy as np
def flow_conv(data, num_filter, flows, weight, bias, name):
    """Warp ``data`` along each flow field and fuse the results with a 1x1 convolution.

    For every entry of ``flows`` a sampling grid is built with
    ``GridGenerator(transform_type="warp")`` from the *negated* flow, ``data``
    is resampled on that grid with ``BilinearSampler``, and all resampled
    copies are concatenated along axis 1 before a single ``(1, 1)``
    convolution mixes them.

    Parameters
    ----------
    data : mx.sym.Symbol
        Feature map to be warped (presumably NCHW -- confirm against caller).
    num_filter : int
        Number of output channels of the fusing 1x1 convolution.
    flows : list of mx.sym.Symbol
        One flow field per warped copy. Must be a non-empty ``list``.
        NOTE(review): each flow is presumably 2-channel as the "warp"
        GridGenerator expects -- confirm.
    weight, bias : mx.sym.Symbol
        Weight and bias of the fusing convolution.
    name : str
        Name given to the fusing convolution.

    Returns
    -------
    mx.sym.Symbol
        Output of the 1x1 convolution over the concatenated warped copies.

    Raises
    ------
    AssertionError
        If ``flows`` is not a ``list``.
    ValueError
        If ``flows`` is empty.
    """
    assert isinstance(flows, list)
    if not flows:
        # Without this guard the failure would surface later as an opaque
        # backend error from concatenating zero inputs.
        raise ValueError("flow_conv requires at least one flow field, got an empty list")

    warped = []
    for flow in flows:
        # The flow is negated before it is turned into a sampling grid.
        grid = mx.sym.GridGenerator(data=-flow, transform_type="warp")
        warped.append(mx.sym.BilinearSampler(data=data, grid=grid))

    stacked = mx.sym.concat(*warped, dim=1)
    return mx.sym.Convolution(data=stacked,
                              num_filter=num_filter,
                              kernel=(1, 1),
                              weight=weight,
                              bias=bias,
                              name=name)
class TrajGRU(BaseConvRNN):
    """Convolutional GRU cell whose recurrent path is warped by generated flows.

    At every step the cell generates ``L`` flow fields from the current input
    and the previous hidden state, warps the previous hidden state along each
    of them (see ``flow_conv``), and uses the fused result as the
    hidden-to-hidden contribution of the reset gate, update gate and
    candidate memory.
    """

    def __init__(self, b_h_w, num_filter, zoneout=0.0, L=5,
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                 h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                 act_type="leaky",
                 prefix="TrajGRU", lr_mult=1.0):
        """Register all parameters of the cell.

        Parameters
        ----------
        b_h_w
            Forwarded to ``BaseConvRNN``
            (presumably (batch, height, width) -- confirm against the base class).
        num_filter : int
            Forwarded to ``BaseConvRNN``; the gate convolutions emit
            ``3 * num_filter`` channels.
        zoneout : float
            Dropout probability of the zoneout mask; zoneout is skipped
            unless the value is greater than 0.
        L : int
            Number of flow fields generated per step.
        i2h_kernel, i2h_stride, i2h_pad, h2h_kernel, h2h_dilate
            Convolution geometry, forwarded to ``BaseConvRNN``.
        act_type : str
            Activation type forwarded to ``BaseConvRNN``.
        prefix : str
            Name prefix forwarded to ``BaseConvRNN``.
        lr_mult : float
            Learning-rate multiplier attached to every registered parameter.
        """
        super(TrajGRU, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_pad=i2h_pad,
                                      i2h_stride=i2h_stride,
                                      act_type=act_type,
                                      prefix=prefix)
        self._L = L
        self._zoneout = zoneout

        # Flow-generator feature layers. The f_conv2_* parameters are
        # registered here but the flow generator below does not use them.
        for param_name in ("i2f_conv1_weight", "i2f_conv1_bias",
                           "h2f_conv1_weight", "h2f_conv1_bias",
                           "f_conv2_weight", "f_conv2_bias"):
            setattr(self, param_name,
                    self.params.get(param_name, lr_mult=lr_mult))

        # Output layer of the flow generator: optionally zero-initialised
        # and trained with its own learning-rate multiplier.
        if cfg.MODEL.TRAJRNN.INIT_GRID:
            logging.info("TrajGRU: Initialize Grid Using Zeros!")
            flow_lr_mult = lr_mult * cfg.MODEL.TRAJRNN.FLOW_LR_MULT
            self.f_out_weight = self.params.get("f_out_weight",
                                                lr_mult=flow_lr_mult,
                                                init=mx.init.Zero())
            self.f_out_bias = self.params.get("f_out_bias",
                                              lr_mult=flow_lr_mult,
                                              init=mx.init.Zero())
        else:
            self.f_out_weight = self.params.get("f_out_weight", lr_mult=lr_mult)
            self.f_out_bias = self.params.get("f_out_bias", lr_mult=lr_mult)

        # Gate convolutions (input-to-hidden and hidden-to-hidden).
        for param_name in ("i2h_weight", "i2h_bias", "h2h_weight", "h2h_bias"):
            setattr(self, param_name,
                    self.params.get(param_name, lr_mult=lr_mult))

    @property
    def state_postfix(self):
        """Name postfixes of the recurrent states: a single state, ``'h'``."""
        return ['h']

    @property
    def state_info(self):
        """Shape and layout description of the single hidden state."""
        hidden_shape = (self._batch_size, self._num_filter,
                        self._state_height, self._state_width)
        return [{'shape': hidden_shape, '__layout__': "NCHW"}]

    def _flow_generator(self, inputs, states, prefix):
        """Build the ``L`` flow fields for one time step.

        Parameters
        ----------
        inputs : mx.sym.Symbol or None
            Current input; when ``None`` only the hidden state drives the flows.
        states : mx.sym.Symbol
            Previous hidden state.
        prefix : str
            Prefix used to name the convolutions and the saved flow dump.

        Returns
        -------
        list of mx.sym.Symbol
            ``L`` symbols obtained by splitting the ``2 * L``-channel flow
            output along axis 1.
        """
        # Both feature convolutions share the same geometry.
        conv1_kwargs = dict(kernel=(5, 5), dilate=(1, 1), pad=(2, 2), num_filter=32)

        input_feat = None
        if inputs is not None:
            input_feat = mx.sym.Convolution(data=inputs,
                                            weight=self.i2f_conv1_weight,
                                            bias=self.i2f_conv1_bias,
                                            name="%s_i2f_conv1" % prefix,
                                            **conv1_kwargs)
        state_feat = mx.sym.Convolution(data=states,
                                        weight=self.h2f_conv1_weight,
                                        bias=self.h2f_conv1_bias,
                                        name="%s_h2f_conv1" % prefix,
                                        **conv1_kwargs)

        if input_feat is None:
            hidden = state_feat
        else:
            hidden = input_feat + state_feat
        hidden = activation(hidden, act_type=self._act_type)

        # NOTE: an optional second 5x5 feature layer (f_conv2) is disabled;
        # the flows are produced directly from the first feature layer.
        flow_field = mx.sym.Convolution(data=hidden,
                                        weight=self.f_out_weight,
                                        bias=self.f_out_bias,
                                        kernel=(5, 5),
                                        pad=(2, 2),
                                        num_filter=self._L * 2)

        if cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS:
            import os
            flow_field = save_npy(flow_field, save_name="%s_flow" % prefix,
                                  save_dir=os.path.join(cfg.MODEL.SAVE_DIR, "flows"))

        per_link = mx.sym.split(flow_field, num_outputs=self._L, axis=1)
        return [per_link[idx] for idx in range(self._L)]

    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
        """Unroll the cell by one step.

        Parameters
        ----------
        inputs : mx.sym.Symbol or None
            Current input; ``None`` means the step is driven by the state only.
        states : list of mx.sym.Symbol, optional
            Previous states; only the first element is used. Ignored when
            ``is_initial`` is true.
        is_initial : bool
            Use ``self.begin_state()`` instead of ``states``.
        ret_mid : bool
            When true, an extra (empty) list is appended to the return value.

        Returns
        -------
        tuple
            ``(next_h, [next_h])`` or, with ``ret_mid``, ``(next_h, [next_h], [])``.
        """
        self._counter += 1
        name = '%s_t%d' % (self._prefix, self._counter)

        states = self.begin_state()[0] if is_initial else states[0]
        assert states is not None

        # Input-to-hidden contribution, one slice per gate / candidate.
        i2h_slice = None
        if inputs is not None:
            i2h = mx.sym.Convolution(data=inputs,
                                     weight=self.i2h_weight,
                                     bias=self.i2h_bias,
                                     kernel=self._i2h_kernel,
                                     stride=self._i2h_stride,
                                     dilate=self._i2h_dilate,
                                     pad=self._i2h_pad,
                                     num_filter=self._num_filter * 3,
                                     name="%s_i2h" % name)
            i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1)
        has_input = i2h_slice is not None

        # Hidden-to-hidden contribution: previous state warped along the flows.
        prev_h = states
        flows = self._flow_generator(inputs=inputs, states=states, prefix=name)
        h2h = flow_conv(data=prev_h, num_filter=self._num_filter * 3, flows=flows,
                        weight=self.h2h_weight, bias=self.h2h_bias, name="%s_h2h" % name)
        h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1)

        reset_pre = i2h_slice[0] + h2h_slice[0] if has_input else h2h_slice[0]
        reset_gate = mx.sym.Activation(reset_pre, act_type="sigmoid",
                                       name=name + "_r")

        update_pre = i2h_slice[1] + h2h_slice[1] if has_input else h2h_slice[1]
        update_gate = mx.sym.Activation(update_pre, act_type="sigmoid",
                                        name=name + "_u")

        # The reset gate only modulates the recurrent part of the candidate.
        candidate = reset_gate * h2h_slice[2]
        if has_input:
            candidate = i2h_slice[2] + candidate
        new_mem = activation(candidate,
                             act_type=self._act_type,
                             name=name + "_h")

        next_h = update_gate * prev_h + (1 - update_gate) * new_mem

        if self._zoneout > 0.0:
            # Positions where the dropout mask is zero keep their previous value.
            mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout)
            next_h = mx.sym.where(mask, next_h, prev_h)

        self._curr_states = [next_h]
        if ret_mid:
            return next_h, [next_h], []
        return next_h, [next_h]
|