import mxnet as mx
from nowcasting.config import cfg
from nowcasting.ops import reset_regs
from nowcasting.operators.common import grid_generator


class PredictionBaseFactory(object):
    """Base class for factories that build sequence-prediction symbols.

    Stores the sequence/frame dimensions, builds a spatial grid symbol via
    ``grid_generator`` and creates the RNN list through the ``_init_rnn``
    hook, which concrete subclasses must implement.
    """

    def __init__(self, batch_size, in_seq_len, out_seq_len, height, width,
                 name="forecaster"):
        """Record the configuration and initialise the RNN list.

        Parameters
        ----------
        batch_size : int
        in_seq_len : int
            Stored as ``_in_seq_len``.
        out_seq_len : int
            Stored as ``_out_seq_len``.
        height : int
        width : int
        name : str
            Stored as ``_name``.
        """
        self._out_typ = cfg.MODEL.OUT_TYPE
        self._batch_size = batch_size
        self._in_seq_len = in_seq_len
        self._out_seq_len = out_seq_len
        self._height = height
        self._width = width
        self._name = name
        self._spatial_grid = grid_generator(batch_size=batch_size,
                                            height=height,
                                            width=width)
        # `_init_rnn` is the subclass hook; every attribute above is already
        # set when it runs.
        self.rnn_list = self._init_rnn()
        self._reset_rnn()

    def _pre_encode_frame(self, frame_data, seqlen):
        """Append the spatial grid and a constant-one channel to the frames.

        Concatenates, along axis 2, ``frame_data``, the spatial grid
        broadcast to ``(seqlen, batch_size, 2, height, width)`` and a tensor
        of ones with shape ``(seqlen, batch_size, 1, height, width)``.

        Parameters
        ----------
        frame_data : mx.sym.Symbol
            NOTE(review): presumably shaped
            ``(seqlen, batch_size, C, height, width)`` so that it can be
            concatenated with the other two inputs -- confirm with callers.
        seqlen : int
            Leading dimension used for the broadcast grid and the ones.

        Returns
        -------
        mx.sym.Symbol
            The concatenated symbol.
        """
        # A leading axis is inserted before broadcasting the grid over
        # `seqlen`.
        grid = mx.sym.broadcast_to(
            mx.sym.expand_dims(self._spatial_grid, axis=0),
            shape=(seqlen, self._batch_size, 2, self._height, self._width))
        ones = mx.sym.ones(
            shape=(seqlen, self._batch_size, 1, self._height, self._width))
        ret = mx.sym.Concat(frame_data, grid, ones, num_args=3, dim=2)
        return ret

    def _init_rnn(self):
        """Create the RNN objects; must be overridden by subclasses.

        The return value is stored in ``rnn_list`` and iterated by
        ``_reset_rnn``, which calls ``reset()`` on every element.
        """
        raise NotImplementedError

    def _reset_rnn(self):
        """Call ``reset()`` on every entry of ``rnn_list``."""
        for rnn in self.rnn_list:
            rnn.reset()

    def reset_all(self):
        """Call ``reset_regs()`` and then reset every RNN in ``rnn_list``."""
        reset_regs()
        self._reset_rnn()


class RecursiveOneStepBaseFactory(PredictionBaseFactory):
    """Prediction factory base that additionally carries a ``use_ss`` flag."""

    def __init__(self, batch_size, in_seq_len, out_seq_len, height, width,
                 use_ss=False, name="forecaster"):
        """Initialise the base factory and remember the ``use_ss`` flag.

        Parameters
        ----------
        batch_size, in_seq_len, out_seq_len, height, width, name
            Forwarded unchanged to ``PredictionBaseFactory.__init__``.
        use_ss : bool
            Stored as ``_use_ss``.
            NOTE(review): presumably toggles scheduled sampling -- confirm
            against the subclasses that read ``_use_ss``.
        """
        super(RecursiveOneStepBaseFactory, self).__init__(
            batch_size=batch_size,
            in_seq_len=in_seq_len,
            out_seq_len=out_seq_len,
            height=height,
            width=width,
            name=name)
        # Honour the caller's choice; this used to be hard-coded to False,
        # which silently discarded the `use_ss` argument.
        self._use_ss = use_ss