File size: 2,159 Bytes
6021dd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import mxnet as mx
from nowcasting.config import cfg
from nowcasting.ops import reset_regs
from nowcasting.operators.common import grid_generator
class PredictionBaseFactory(object):
    """Base class for factories that assemble an MXNet forecasting network.

    Subclasses must override ``_init_rnn`` to return a list of RNN objects;
    every object in that list must provide a ``reset()`` method, which is
    called once at the end of construction and again from ``reset_all``.
    """

    def __init__(self, batch_size, in_seq_len, out_seq_len, height, width, name="forecaster"):
        """Store the network dimensions, build the grid symbol and the RNNs.

        Parameters
        ----------
        batch_size, in_seq_len, out_seq_len, height, width : int
            Stored as private attributes for use by subclasses.
        name : str, optional
            Stored as ``self._name``; defaults to ``"forecaster"``.
        """
        # The output type is taken from the global config, not from an argument.
        self._out_typ = cfg.MODEL.OUT_TYPE
        self._batch_size, self._in_seq_len, self._out_seq_len = (
            batch_size, in_seq_len, out_seq_len)
        self._height, self._width = height, width
        self._name = name
        # Grid symbol that _pre_encode_frame tiles to 2 channels and appends
        # to every frame.
        self._spatial_grid = grid_generator(batch_size=batch_size,
                                            height=height,
                                            width=width)
        # _init_rnn is the subclass hook; the RNNs are reset right after creation.
        self.rnn_list = self._init_rnn()
        self._reset_rnn()

    def _pre_encode_frame(self, frame_data, seqlen):
        """Append the tiled spatial grid and a constant-ones channel to frames.

        Parameters
        ----------
        frame_data : mx.sym.Symbol
            Must match ``(seqlen, batch_size, *, height, width)`` on every
            axis except axis 2, since the concatenation is along axis 2.
        seqlen : int
            Number of time steps the grid and ones tensors are tiled to.

        Returns
        -------
        mx.sym.Symbol
            ``frame_data`` with 3 extra channels on axis 2: the 2-channel
            grid followed by one channel of ones.
        """
        grid_shape = (seqlen, self._batch_size, 2, self._height, self._width)
        ones_shape = (seqlen, self._batch_size, 1, self._height, self._width)
        # Prepend a time axis to the grid, then broadcast it over all steps.
        tiled_grid = mx.sym.broadcast_to(mx.sym.expand_dims(self._spatial_grid, axis=0),
                                         shape=grid_shape)
        ones_channel = mx.sym.ones(shape=ones_shape)
        return mx.sym.Concat(frame_data, tiled_grid, ones_channel, num_args=3, dim=2)

    def _init_rnn(self):
        """Build and return the list of RNN objects; subclasses must override."""
        raise NotImplementedError

    def _reset_rnn(self):
        """Call ``reset()`` on every object in ``self.rnn_list``."""
        for cell in self.rnn_list:
            cell.reset()

    def reset_all(self):
        """Call ``reset_regs()`` and then reset every RNN in ``rnn_list``."""
        reset_regs()
        self._reset_rnn()
class RecursiveOneStepBaseFactory(PredictionBaseFactory):
    """``PredictionBaseFactory`` subclass that additionally records a ``use_ss`` flag."""

    def __init__(self, batch_size, in_seq_len, out_seq_len, height, width, use_ss=False,
                 name="forecaster"):
        """Record ``use_ss`` and delegate everything else to the base class.

        Parameters
        ----------
        batch_size, in_seq_len, out_seq_len, height, width : int
            Forwarded unchanged to ``PredictionBaseFactory.__init__``.
        use_ss : bool, optional
            Stored as ``self._use_ss``; defaults to False. Presumably toggles
            scheduled sampling in subclasses -- TODO confirm with callers.
        name : str, optional
            Forwarded to the base class; defaults to ``"forecaster"``.
        """
        # Previously this attribute was hard-coded to False, silently discarding
        # use_ss=True. It is set *before* the base constructor because that
        # constructor calls self._init_rnn(), and subclass overrides may need it.
        self._use_ss = use_ss
        super(RecursiveOneStepBaseFactory, self).__init__(batch_size=batch_size,
                                                          in_seq_len=in_seq_len,
                                                          out_seq_len=out_seq_len,
                                                          height=height,
                                                          width=width,
                                                          name=name)
|