# STLDM_official / nowcasting / models / deconvolution_symbol.py
# (Hub upload metadata: uploaded by sqfoo, "Upload 99 files", commit 6021dd1, verified)
import mxnet as mx
from nowcasting.ops import \
conv3d, conv3d_act, conv3d_bn_act, \
conv2d, conv2d_act, conv2d_bn_act, \
deconv3d, deconv3d_act, deconv3d_bn_act, \
deconv2d, deconv2d_act, deconv2d_bn_act, \
fc_layer, fc_layer_act
from nowcasting.config import cfg
### Network structure
def encode_net_symbol(data,
                      data_type,
                      no_bias=False,
                      momentum=0.9,
                      fix_gamma=False,
                      eps=1e-5 + 1e-12,
                      postfix=""):
    """Construct encode_net symbol.

    Encodes a video clip into a compact feature map via a stack of strided
    convolutions. Kernel/stride/pad per layer are chosen from the dataset and
    the clip length so that the output is always
    (batch_size, 4 * BASE_NUM_FILTER, 1, 4, 4) in the 3D case, or
    (batch_size, 4 * BASE_NUM_FILTER, 4, 4) in the 2D case.

    Args:
        data: input data (context or pred), shape
            (batch_size, 1, length, IMG_SIZE, IMG_SIZE).
        data_type: If "context" use IN_LEN, if "pred" use OUT_LEN, if
            "contextpred" use IN_LEN + OUT_LEN.
        no_bias: disable the bias term of every conv layer.
        momentum: BatchNorm momentum.
        fix_gamma: BatchNorm fix_gamma flag.
        eps: BatchNorm epsilon.
        postfix: Postfix for symbol names. Parameters will be shared
            between symbols created during calls to encode_net_symbol with
            the same data_type and postfix arguments.

    Raises:
        NotImplementedError: if data_type is not one of the three values above.
    """
    if cfg.DATASET == "MOVINGMNIST":
        IN_LEN = cfg.MOVINGMNIST.IN_LEN
        OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
        IMG_SIZE = cfg.MOVINGMNIST.IMG_SIZE
    elif cfg.DATASET == "HKO":
        IN_LEN = cfg.HKO.BENCHMARK.IN_LEN
        OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
        IMG_SIZE = cfg.HKO.ITERATOR.WIDTH
    # Input
    # (cfg.MODEL.TRAIN.BATCH_SIZE, 1, IN_LEN, IMG_SIZE, IMG_SIZE)
    # Determine length.
    if data_type == "context":
        length = IN_LEN
    elif data_type == "pred":
        length = OUT_LEN
    elif data_type == "contextpred":
        length = IN_LEN + OUT_LEN
    else:
        raise NotImplementedError
    # Postfix for symbol names.
    postfix = "_" + data_type + "_" + postfix
    if not cfg.MODEL.DECONVBASELINE.USE_3D:
        # 2D case: fold the temporal axis into the channel axis so that
        # plain 2D convolutions can process the whole clip at once.
        data = mx.sym.reshape(
            data,
            shape=(cfg.MODEL.TRAIN.BATCH_SIZE, length, IMG_SIZE, IMG_SIZE))
    # Assertions (only clip lengths the layer schedule below was tuned for).
    if cfg.DATASET == "MOVINGMNIST":
        assert (length in [1, 10, 11, 20])
    elif cfg.DATASET == "HKO":
        assert (length in [1, 5, 20, 21, 25])
    # (depth, height, width) kernel/stride/pad — mutated in place between layers.
    k = [1, 1, 1]
    s = [1, 1, 1]
    p = [0, 0, 0]
    if cfg.DATASET == "HKO" or (cfg.DATASET == "MOVINGMNIST" and length == 20):
        # For MOVINGMNIST, if data_type == contextpred and OUT_LEN == 10 frames,
        # i.e. length == 20. If length == 11 we don't need this.
        # For HKO, if length in [20, 21, 25]
        if length > 11:
            k[0] = 4
            s[0] = 2
            p[0] = 1
    # For HKO, IMG_SIZE == 480, we scale it down to 96
    if cfg.DATASET == "HKO":
        k[1:] = [7, 7]
        s[1:] = [5, 5]
        p[1:] = [1, 1]
    data = conv2d_3d_act(
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        data=data,
        name='encode_net_0' + postfix,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER,
        no_bias=no_bias)
    # Set convolution parameters for height and width
    k[1:] = [4, 4]
    s[1:] = [2, 2]
    p[1:] = [1, 1]
    # Set convolution parameters for sequence length
    # I.e. start reducing sequence length, if input length >= 10
    if length >= 10:
        k[0] = 4
        s[0] = 2
        p[0] = 1
    # For HKO the HEIGHT and WIDTH is still 96,
    # we therefore increase stride to 3
    if cfg.DATASET == "HKO":
        s[1:] = [3, 3]
    e1 = conv2d_3d_act(
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        data=data,
        name='encode_net_1' + postfix,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER,
        no_bias=no_bias)
    # Set convolution parameters for sequence length
    # I.e. start reducing sequence length, if input length >= 5
    if length >= 5:
        k[0] = 4
        s[0] = 2
        p[0] = 1
    # Reset stride if previously changed
    if cfg.DATASET == "HKO":
        s[1:] = [2, 2]
    e2 = conv2d_3d_bn_act(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        data=e1,
        name='encode_net_2' + postfix,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER * 2,
        no_bias=no_bias,
        height=16,
        width=16,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)
    e3 = conv2d_3d_bn_act(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        data=e2,
        name='encode_net_3' + postfix,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER * 3,
        no_bias=no_bias,
        height=8,
        width=8,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)
    # Increase padding for sequence length
    p[0] = 2
    e4 = conv2d_3d_bn_act(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        data=e3,
        name='encode_net_4' + postfix,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER * 4,
        no_bias=no_bias,
        height=4,
        width=4,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)
    # Output
    # (batch_size, 4 * num_filter, 1, 4, 4)
    # or in 2D case
    # (batch_size, 4 * num_filter, 4, 4)
    return e4
def video_net_symbol(encode_net,
                     no_bias=False,
                     momentum=0.9,
                     fix_gamma=False,
                     eps=1e-5 + 1e-12):
    """Construct the decoder ("video net") symbol.

    Upsamples an encoding back to a clip of OUT_LEN frames through a stack of
    strided deconvolutions; kernel/stride/pad per layer depend on the dataset
    and OUT_LEN.

    Args:
        encode_net: encoding symbol, (batch_size, num_filter * 4, 1, 4, 4)
            in the 3D case or (batch_size, 4 * num_filter, 4, 4) in 2D.
        no_bias: disable the bias term of every deconv layer.
        momentum / fix_gamma / eps: BatchNorm hyper-parameters.

    Returns:
        gen_out: generated grayscale video symbol, e.g.
            (batch_size, 1, 10, 64, 64) for MOVINGMNIST with OUT_LEN == 10
            (in the 2D case the time axis is folded into the channel axis).
    """
    if cfg.DATASET == "MOVINGMNIST":
        OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
    elif cfg.DATASET == "HKO":
        OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
    # Input
    # (batch_size, num_filter * 4, 1, 4, 4)
    # or in 2D case
    # (batch_size, 4 * num_filter, 4, 4)
    assert (OUT_LEN in [1, 10, 20])
    # (depth, height, width) kernel/stride/pad — mutated in place between layers.
    k = [1, 1, 1]
    s = [1, 1, 1]
    p = [0, 0, 0]
    if OUT_LEN > 1:
        # Start growing the temporal axis.
        k[0] = 2
    d1 = deconv2d_3d_act(
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        data=encode_net,
        name='video_net_d1',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER * 8,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        no_bias=no_bias)
    # Upsample height and width by a factor of 2 from here on.
    k[1:] = [4, 4]
    s[1:] = [2, 2]
    p[1:] = [1, 1]
    if OUT_LEN >= 10:
        # Also upsample the temporal axis by a factor of 2.
        k[0] = 4
        s[0] = 2
        p[0] = 1
    d2 = deconv2d_3d_bn_act(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        data=d1,
        name='video_net_d2',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER * 4,
        no_bias=no_bias,
        height=8,
        width=8,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)
    # Temporal padding tuned so the final sequence length comes out at OUT_LEN.
    if OUT_LEN == 10:
        p[0] = 2
    elif OUT_LEN == 20:
        p[0] = 0
    d3 = deconv2d_3d_bn_act(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        data=d2,
        name='video_net_d3',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER * 2,
        no_bias=no_bias,
        height=16,
        width=16,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)
    if OUT_LEN == 20:
        p[0] = 1
    d4 = deconv2d_3d_bn_act(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        act_type=cfg.MODEL.CNN_ACT_TYPE,
        data=d3,
        name='video_net_d4',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER,
        no_bias=no_bias,
        height=32,
        width=32,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)
    out_filter = 1
    if OUT_LEN > 1:
        # Keep the temporal length fixed (stride 1) in the output layer(s).
        k[0] = 3
        s[0] = 1
        p[0] = 1
    # For HKO, scale up to 96 instead of 64
    if cfg.DATASET == "HKO":
        k[1:] = [5, 5]
        s[1:] = [3, 3]
        p[1:] = [1, 1]
        out_filter = 8
    if cfg.MODEL.DECONVBASELINE.USE_3D:
        gen_out = mx.sym.Deconvolution(
            data=d4,
            name='gen_out',
            kernel=k,
            stride=s,
            pad=p,
            # Generate grayscale video with only 1 channel
            num_filter=out_filter,
            no_bias=no_bias)
    else:
        gen_out = mx.sym.Deconvolution(
            data=d4,
            name='gen_out',
            kernel=k[1:],
            stride=s[1:],
            pad=p[1:],
            # Generate grayscale video with only 1 channel
            num_filter=OUT_LEN * out_filter,
            no_bias=no_bias)
    # For HKO we need to scale up further from 96 to 480
    if cfg.DATASET == "HKO":
        k[1:] = [7, 7]
        s[1:] = [5, 5]
        p[1:] = [1, 1]
        if cfg.MODEL.DECONVBASELINE.USE_3D:
            gen_out = mx.sym.Deconvolution(
                data=gen_out,
                name='gen_out_scale',
                kernel=k,
                stride=s,
                pad=p,
                # Generate grayscale video with only 1 channel
                num_filter=1 * out_filter,
                no_bias=no_bias)
        else:
            gen_out = mx.sym.Deconvolution(
                data=gen_out,
                name='gen_out_scale',
                kernel=k[1:],
                stride=s[1:],
                pad=p[1:],
                # Generate grayscale video with only 1 channel
                num_filter=OUT_LEN * out_filter,
                no_bias=no_bias)
    # For HKO we add a final refinement layer
    if cfg.DATASET == "HKO":
        k[1:] = [3, 3]
        s[1:] = [1, 1]
        p[1:] = [1, 1]
        if cfg.MODEL.DECONVBASELINE.USE_3D:
            gen_out = mx.sym.Deconvolution(
                data=gen_out,
                name='gen_out_scale2',
                kernel=k,
                stride=s,
                pad=p,
                # Generate grayscale video with only 1 channel
                num_filter=1,
                no_bias=no_bias)
        else:
            gen_out = mx.sym.Deconvolution(
                data=gen_out,
                name='gen_out_scale2',
                kernel=k[1:],
                stride=s[1:],
                pad=p[1:],
                # Generate grayscale video with only 1 channel
                num_filter=OUT_LEN,
                no_bias=no_bias)
    # Output
    # gen_out (batch_size, 1, 10, 64, 64)
    return gen_out
def generator_symbol(context,
                     no_bias=False,
                     momentum=0.9,
                     fix_gamma=False,
                     eps=1e-5 + 1e-12):
    """Build the full generator: encode the context clip, optionally pass the
    encoding through a fully-connected bottleneck, then decode it into the
    predicted frames.

    Returns a symbol group of the raw prediction ("pred") and a
    gradient-blocked copy clipped to [0, 1] ("forecast_target").
    """
    encoding = encode_net_symbol(
        data=context,
        data_type="context",
        no_bias=no_bias,
        momentum=momentum,
        fix_gamma=fix_gamma,
        eps=eps)
    fc_hidden = cfg.MODEL.DECONVBASELINE.FC_BETWEEN_ENCDEC
    if fc_hidden:
        encoding = mx.sym.FullyConnected(data=encoding, num_hidden=fc_hidden)
        # The FC layer flattens everything; restore the layout the decoder expects.
        if cfg.MODEL.DECONVBASELINE.USE_3D:
            dec_shape = (cfg.MODEL.TRAIN.BATCH_SIZE, -1, 1, 4, 4)
        else:
            dec_shape = (cfg.MODEL.TRAIN.BATCH_SIZE, -1, 4, 4)
        encoding = mx.sym.Reshape(data=encoding, shape=dec_shape)
    prediction = video_net_symbol(
        encoding,
        no_bias=no_bias,
        momentum=momentum,
        fix_gamma=fix_gamma,
        eps=eps)
    if cfg.DATASET == "MOVINGMNIST":
        out_len = cfg.MOVINGMNIST.OUT_LEN
        img_size = cfg.MOVINGMNIST.IMG_SIZE
    elif cfg.DATASET == "HKO":
        out_len = cfg.HKO.BENCHMARK.OUT_LEN
        img_size = cfg.HKO.ITERATOR.WIDTH
    # No operation if cfg.MODEL.DECONVBASELINE.USE_3D is True; otherwise it
    # lifts the 2D decoder output back to 5D (batch, 1, T, H, W).
    prediction = mx.sym.reshape(
        prediction,
        shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, out_len, img_size, img_size),
        name="pred")
    forecast = mx.sym.BlockGrad(
        mx.sym.clip(prediction, a_min=0, a_max=1), name="forecast_target")
    return mx.sym.Group([prediction, forecast])
def discriminator_symbol(context,
                         pred,
                         no_bias=False,
                         momentum=0.9,
                         fix_gamma=False,
                         eps=1e-5 + 1e-12):
    """Construct the discriminator symbol.

    Encodes context and prediction (either with shared/separate encoders or a
    single encoder over their concatenation) and maps the joint encoding to a
    single scalar logit per sample.

    Args:
        context: (batch_size, 1, input_length, 64, 64)
        pred: (batch_size, 1, output_length, 64, 64)
        no_bias / momentum / fix_gamma / eps: passed through to the encoder.

    Raises:
        NotImplementedError: for an unknown cfg.MODEL.DECONVBASELINE.ENCODER.
    """
    # context: (batch_size, 1, input_length, 64, 64)
    # pred: (batch_size, 1, output_length, 64, 64)
    if cfg.DATASET == "MOVINGMNIST":
        OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
    elif cfg.DATASET == "HKO":
        OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
    # NOTE(review): 'mask' is an externally supplied variable multiplied into
    # the prediction — presumably zeroing out invalid/missing pixels before
    # encoding; confirm its semantics against the data iterator.
    mask = mx.sym.Variable('mask')
    pred = pred * mask
    if cfg.MODEL.DECONVBASELINE.ENCODER in ["shared", "separate"]:
        # "shared": the context encoder reuses the generator's weights (same
        # symbol names). "separate": a "_gan" postfix gives the GAN its own
        # context-encoder parameters.
        postfix = "" if cfg.MODEL.DECONVBASELINE.ENCODER == "shared" else "_gan"
        context_encoding = encode_net_symbol(
            data=context,
            data_type="context",
            no_bias=no_bias,
            momentum=momentum,
            fix_gamma=fix_gamma,
            eps=eps,
            postfix=postfix)
        pred_encoding = encode_net_symbol(
            data=pred,
            data_type="pred",
            no_bias=no_bias,
            momentum=momentum,
            fix_gamma=fix_gamma,
            eps=eps)
        # context_encoding: (batch_size, 4 * num_filter, 1, 4, 4)
        # pred_encoding: (batch_size, 4 * num_filter, 1, 4, 4)
        if cfg.MODEL.DECONVBASELINE.USE_3D:
            # Concatenate along the temporal axis (dim 2) in the 3D case...
            context_pred = mx.sym.concat(
                context_encoding, pred_encoding, dim=2)
        else:
            # ...and along the channel axis (dim 1) in the 2D case.
            context_pred = mx.sym.concat(
                context_encoding, pred_encoding, dim=1)
        # Compatibility code
        if cfg.MODEL.DECONVBASELINE.COMPAT.CONV_INSTEADOF_FC_IN_ENCODER:
            # Introduce extra layer to merge context and pred representations
            d5 = conv2d_3d_bn_act(
                use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
                use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
                use_bn=cfg.MODEL.DECONVBASELINE.BN,
                data=context_pred,
                name='discriminator_5',
                act_type=cfg.MODEL.CNN_ACT_TYPE,
                kernel=(1, 1, 1),
                stride=(1, 1, 1),
                pad=(0, 0, 0),
                num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER,
                no_bias=no_bias,
                height=4,
                width=4,
                fix_gamma=fix_gamma,
                eps=eps,
                momentum=momentum)
            # 4x4 valid conv collapses the spatial dims to a single logit.
            d6 = conv2d_3d(
                use_3d=cfg.MODEL.DECONVBASELINE.USE_3D,
                data=d5,
                name='discriminator_6',
                kernel=(1, 4, 4),
                stride=(1, 1, 1),
                pad=(0, 0, 0),
                num_filter=1,
                no_bias=no_bias)
            # Early return: the conv head replaces the FC head below.
            return mx.sym.Flatten(d6)
        else:
            # flattened_encoding: (batch_size, num_filter * 4^3)
            flattened_encoding = mx.sym.Flatten(data=context_pred)
    elif cfg.MODEL.DECONVBASELINE.ENCODER == "concat":
        # Concatenate the raw clips along time and encode them jointly.
        context_pred = mx.sym.concat(context, pred, dim=2)
        encoding = encode_net_symbol(
            data=context_pred,
            data_type="contextpred",
            no_bias=no_bias,
            momentum=momentum,
            fix_gamma=fix_gamma,
            eps=eps)
        flattened_encoding = mx.sym.Flatten(data=encoding)
    else:
        raise NotImplementedError
    # Two-layer FC head producing one logit per sample.
    fc1 = fc_layer_act(
        data=flattened_encoding,
        num_hidden=256,
        name="discriminator_fc_1",
        act_type=cfg.MODEL.CNN_ACT_TYPE)
    return fc_layer(data=fc1, num_hidden=1, name="discriminator_fc_2")
### Helpers
def batchnorm_5d(data, height, width, name, fix_gamma, eps, momentum):
    """Batch-normalize a 5D (N, C, D, H, W) tensor with MXNet's 4D BatchNorm.

    Depth and height are folded into a single axis before normalization and
    unfolded afterwards, so statistics are computed per channel as usual.
    """
    folded = mx.symbol.reshape(data, shape=(0, 0, -1, width))
    normalized = mx.sym.BatchNorm(
        folded,
        name=name,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum,
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS)
    # Restore the original 5D layout (0 keeps a dim, -1 infers the depth).
    return mx.symbol.reshape(normalized, shape=(0, 0, -1, height, width))
def conv2d_3d(data,
              num_filter,
              kernel=(1, 1, 1),
              stride=(1, 1, 1),
              pad=(0, 0, 0),
              dilate=(1, 1, 1),
              no_bias=False,
              name=None,
              use_3d=True,
              **kwargs):
    """Plain convolution, 3D or parameter-matched 2D.

    If use_3d == False a 2D convolution with the same number of parameters is
    used: the temporal kernel extent is folded into the filter count and the
    leading (temporal) entry of each geometry tuple is dropped.
    """
    if not use_3d:
        return conv2d(
            data=data,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            dilate=dilate[1:],
            no_bias=no_bias,
            name=name,
            **kwargs)
    return conv3d(
        data=data,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        dilate=dilate,
        no_bias=no_bias,
        name=name,
        **kwargs)
def conv2d_3d_bn_act(data,
                     num_filter,
                     height,
                     width,
                     kernel=(1, 1, 1),
                     stride=(1, 1, 1),
                     pad=(0, 0, 0),
                     dilate=(1, 1, 1),
                     no_bias=False,
                     act_type="relu",
                     momentum=0.9,
                     eps=1e-5 + 1e-12,
                     fix_gamma=True,
                     name=None,
                     use_3d=True,
                     use_bn=True,
                     use_global_stats=False,
                     **kwargs):
    """Convolution + optional batch norm + activation.

    If use_3d == False a 2D convolution with the same number of parameters is
    used (temporal kernel extent folded into the filter count). With
    use_bn == False this degenerates to conv2d_3d_act and the batch-norm
    arguments (height, width, momentum, eps, fix_gamma, use_global_stats)
    are not used.
    """
    if not use_bn:
        # No batch norm: delegate to the conv + activation helper.
        return conv2d_3d_act(
            data=data,
            num_filter=num_filter,
            kernel=kernel,
            stride=stride,
            pad=pad,
            dilate=dilate,
            no_bias=no_bias,
            act_type=act_type,
            name=name,
            use_3d=use_3d)
    if not use_3d:
        return conv2d_bn_act(
            data=data,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            dilate=dilate[1:],
            no_bias=no_bias,
            act_type=act_type,
            momentum=momentum,
            eps=eps,
            fix_gamma=fix_gamma,
            name=name,
            use_global_stats=use_global_stats,
            **kwargs)
    return conv3d_bn_act(
        data=data,
        num_filter=num_filter,
        height=height,
        width=width,
        kernel=kernel,
        stride=stride,
        pad=pad,
        dilate=dilate,
        no_bias=no_bias,
        act_type=act_type,
        momentum=momentum,
        eps=eps,
        fix_gamma=fix_gamma,
        name=name,
        use_global_stats=use_global_stats,
        **kwargs)
def conv2d_3d_act(data,
                  num_filter,
                  kernel=(1, 1, 1),
                  stride=(1, 1, 1),
                  pad=(0, 0, 0),
                  dilate=(1, 1, 1),
                  no_bias=False,
                  act_type="relu",
                  name=None,
                  use_3d=True,
                  **kwargs):
    """Convolution followed by an activation.

    If use_3d == False a 2D convolution with the same number of parameters is
    used: the temporal kernel extent is folded into the filter count and the
    leading (temporal) entry of each geometry tuple is dropped.
    """
    if not use_3d:
        return conv2d_act(
            data=data,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            dilate=dilate[1:],
            no_bias=no_bias,
            act_type=act_type,
            name=name,
            **kwargs)
    return conv3d_act(
        data=data,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        dilate=dilate,
        no_bias=no_bias,
        act_type=act_type,
        name=name,
        **kwargs)
def deconv2d_3d(data,
                num_filter,
                kernel=(1, 1, 1),
                stride=(1, 1, 1),
                pad=(0, 0, 0),
                adj=(0, 0, 0),
                no_bias=True,
                target_shape=None,
                name=None,
                use_3d=True,
                **kwargs):
    """Plain deconvolution (no activation), 3D or parameter-matched 2D.

    If use_3d == False use a 2D deconvolution with the same number of
    parameters (the temporal kernel extent is folded into the filter count
    and the leading entry of each geometry tuple is dropped).

    BUGFIX: this helper previously forwarded ``act_type``, a name that is
    neither a parameter nor otherwise in scope, so every call raised
    NameError; it also dispatched to the ``deconv3d_act``/``deconv2d_act``
    variants even though — by analogy with ``conv2d_3d`` — it should apply no
    activation. It now calls the activation-free ``deconv3d``/``deconv2d``
    ops, which are already imported at module level.
    """
    if use_3d:
        return deconv3d(
            data=data,
            num_filter=num_filter,
            kernel=kernel,
            stride=stride,
            pad=pad,
            adj=adj,
            no_bias=no_bias,
            target_shape=target_shape,
            name=name,
            **kwargs)
    else:
        return deconv2d(
            data=data,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            adj=adj[1:],
            no_bias=no_bias,
            target_shape=target_shape,
            name=name,
            **kwargs)
def deconv2d_3d_bn_act(data,
                       num_filter,
                       height,
                       width,
                       kernel=(1, 1, 1),
                       stride=(1, 1, 1),
                       pad=(0, 0, 0),
                       adj=(0, 0, 0),
                       no_bias=True,
                       target_shape=None,
                       act_type="relu",
                       momentum=0.9,
                       eps=1e-5 + 1e-12,
                       fix_gamma=True,
                       name=None,
                       use_3d=True,
                       use_bn=True,
                       use_global_stats=False,
                       **kwargs):
    """Deconvolution + optional batch norm + activation.

    If use_3d == False a 2D deconvolution with the same number of parameters
    is used (temporal kernel extent folded into the filter count). With
    use_bn == False this degenerates to deconv2d_3d_act and the batch-norm
    arguments (height, width, momentum, eps, fix_gamma, use_global_stats)
    are not used.

    FIX: the use_bn == False path previously dropped ``target_shape`` (unlike
    the batch-norm paths, which forward it), so a caller relying on
    target-shape-driven output sizing silently lost it when batch norm was
    disabled. It is now forwarded.
    """
    if not use_bn:
        return deconv2d_3d_act(
            data=data,
            num_filter=num_filter,
            kernel=kernel,
            stride=stride,
            pad=pad,
            adj=adj,
            no_bias=no_bias,
            target_shape=target_shape,
            act_type=act_type,
            name=name,
            use_3d=use_3d)
    if use_3d:
        return deconv3d_bn_act(
            data=data,
            num_filter=num_filter,
            height=height,
            width=width,
            kernel=kernel,
            stride=stride,
            pad=pad,
            adj=adj,
            no_bias=no_bias,
            target_shape=target_shape,
            act_type=act_type,
            momentum=momentum,
            eps=eps,
            fix_gamma=fix_gamma,
            name=name,
            use_global_stats=use_global_stats,
            **kwargs)
    else:
        return deconv2d_bn_act(
            data=data,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            adj=adj[1:],
            no_bias=no_bias,
            target_shape=target_shape,
            act_type=act_type,
            momentum=momentum,
            eps=eps,
            fix_gamma=fix_gamma,
            name=name,
            use_global_stats=use_global_stats,
            **kwargs)
def deconv2d_3d_act(data,
                    num_filter,
                    kernel=(1, 1, 1),
                    stride=(1, 1, 1),
                    pad=(0, 0, 0),
                    adj=(0, 0, 0),
                    no_bias=True,
                    target_shape=None,
                    act_type="relu",
                    name=None,
                    use_3d=True,
                    **kwargs):
    """Deconvolution followed by an activation.

    If use_3d == False a 2D deconvolution with the same number of parameters
    is used: the temporal kernel extent is folded into the filter count and
    the leading (temporal) entry of each geometry tuple is dropped.
    """
    if not use_3d:
        return deconv2d_act(
            data=data,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            adj=adj[1:],
            no_bias=no_bias,
            target_shape=target_shape,
            act_type=act_type,
            name=name,
            **kwargs)
    return deconv3d_act(
        data=data,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        adj=adj,
        no_bias=no_bias,
        target_shape=target_shape,
        act_type=act_type,
        name=name,
        **kwargs)