import mxnet as mx
import numpy as np

from nowcasting.operators.common import constant


def CDNA(data, kernels, mask, batch_size, num_filter, kernel_size):
    """Convolutional Dynamic Neural Advection (CDNA) transformation.

    We assume that the kernels and the mask are the raw network outputs (identity
    activation); the softmax normalizations are applied inside this function.

    Parameters
    ----------
    data : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    kernels : mx.sym.Symbol
        Shape: (batch_size, M, K, K)
    mask : mx.sym.Symbol
        Shape: (batch_size, M, H, W)
    batch_size : int
        Size of the minibatch
    num_filter : int
        Number of transformation kernels, M
    kernel_size : int
        Size of each kernel, K

    Returns
    -------
    ret : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    """
    assert kernel_size % 2 == 1, "Only support odd kernel size"
    # Normalize each kernel over its K*K entries and the mask over the M channels.
    kernels = mx.sym.SoftmaxActivation(mx.sym.Reshape(kernels,
                                                      shape=(-1, kernel_size * kernel_size)))
    kernels = mx.sym.Reshape(kernels, shape=(-1, num_filter, kernel_size, kernel_size))
    mask = mx.sym.SoftmaxActivation(mask, mode="channel")
    # Apply a different set of kernels to every sample: slice along the batch axis and
    # convolve each sample's channels (treated as the batch axis of the convolution)
    # with that sample's M kernels.
    data_sliced = mx.sym.SliceChannel(mx.sym.expand_dims(data, axis=2), axis=0,
                                      num_outputs=batch_size, squeeze_axis=True)
    kernels_sliced = mx.sym.SliceChannel(mx.sym.expand_dims(kernels, axis=2),
                                         axis=0, num_outputs=batch_size,
                                         squeeze_axis=True)
    out = []
    for i in range(batch_size):
        ele = mx.sym.Convolution(data=data_sliced[i],
                                 num_filter=num_filter,
                                 kernel=(kernel_size, kernel_size),
                                 pad=(kernel_size // 2, kernel_size // 2),
                                 weight=kernels_sliced[i], no_bias=True)
        out.append(mx.sym.expand_dims(ele, axis=0))
    out = mx.sym.Concat(*out, num_args=batch_size, dim=0)  # (batch_size, C, M, H, W)
    # Blend the M transformed images with the softmax mask and sum them out.
    mask = mx.sym.Reshape(mask, reverse=True, shape=(batch_size, 1, num_filter, 0, 0))
    out = mx.sym.broadcast_mul(out, mask)
    out = mx.sym.sum(out, axis=2)
    return out
|
|
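# Hedged reference sketch (not in the original module): a plain-NumPy version of the
# CDNA blend above, handy for checking shapes and semantics on small inputs. It assumes
# the Convolution op behaves as zero-padded cross-correlation, so kernel orientation may
# differ from a flipped-kernel convolution.
def _cdna_numpy_reference(data, kernels, mask):
    """data: (B, C, H, W), kernels: (B, M, K, K), mask: (B, M, H, W) -> (B, C, H, W)"""
    B, C, H, W = data.shape
    _, M, K, _ = kernels.shape
    # Softmax over the K*K entries of each kernel and over the M mask channels.
    k = np.exp(kernels.reshape(B, M, -1))
    k = (k / k.sum(axis=-1, keepdims=True)).reshape(B, M, K, K)
    m = np.exp(mask)
    m = m / m.sum(axis=1, keepdims=True)
    pad = K // 2
    padded = np.pad(data, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
    out = np.zeros((B, C, H, W))
    for b in range(B):
        for c in range(C):
            for i in range(H):
                for j in range(W):
                    patch = padded[b, c, i:i + K, j:j + K]
                    # Response of the patch under each of the M kernels, blended by the mask.
                    responses = np.tensordot(k[b], patch, axes=([1, 2], [0, 1]))  # (M,)
                    out[b, c, i, j] = np.sum(m[b, :, i, j] * responses)
    return out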
|
|
|
def STP(data, affine_transform_matrices, mask, num_filter, kernel_size):
    """Spatial Transformer Predictor (not implemented yet)

    Parameters
    ----------
    data : mx.sym.Symbol
    affine_transform_matrices
    mask

    Returns
    -------

    """
    raise NotImplementedError()


def DFN(data, local_kernels, K, batch_size):
    """[NIPS2016] Dynamic Filter Network

    Applies a different K x K filter at every spatial location. The local kernels are
    assumed to be the raw network outputs; the channel-wise softmax is applied inside
    this function.

    Parameters
    ----------
    data : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    local_kernels : mx.sym.Symbol
        Shape: (batch_size, K*K, H, W)
    K : int
        Size of the local convolutional kernel
    batch_size : int
        Size of the minibatch

    Returns
    -------
    output : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    """
    local_kernels = mx.sym.SoftmaxActivation(local_kernels, mode="channel")
    # Fixed one-hot filters that copy each of the K*K neighborhood positions into its
    # own output channel (an im2col-style expansion expressed as a convolution).
    filter_localexpand = mx.sym.one_hot(indices=mx.sym.arange(K * K), depth=K * K)
    filter_localexpand = mx.sym.reshape(mx.sym.transpose(filter_localexpand, axes=(1, 0)),
                                        shape=(K * K, 1, K, K))
    data_sliced = mx.sym.SliceChannel(data, num_outputs=batch_size, axis=0, squeeze_axis=True)
    vec = []
    for i in range(batch_size):
        # Expand every pixel's K x K neighborhood; the sample's channels serve as the
        # batch axis of the convolution.
        ele = mx.sym.Convolution(data=mx.sym.expand_dims(data_sliced[i], axis=1),
                                 weight=filter_localexpand,
                                 num_filter=K * K,
                                 kernel=(K, K),
                                 pad=(K // 2, K // 2), no_bias=True)
        vec.append(mx.sym.expand_dims(ele, axis=0))
    # Shape: (batch_size, C, K*K, H, W)
    input_localexpanded = mx.sym.Concat(*vec, num_args=len(vec), dim=0)
    # Weight each expanded neighborhood with the per-pixel softmax kernels and sum.
    output = mx.sym.broadcast_mul(input_localexpanded, mx.sym.expand_dims(local_kernels, axis=1))
    output = mx.sym.sum(output, axis=2)
    return output


if __name__ == '__main__':
    # Quick smoke test of DFN: stack it ten times and check the output shape.
    data = mx.sym.Variable('data')
    local_kernels = mx.sym.Variable('local_kernels')
    K = 11
    C = 3
    H = 60
    W = 60
    batch_size = 32
    local_kernels_npy = np.random.normal(size=(batch_size, K * K, H, W))
    data_npy = np.random.normal(size=(batch_size, C, H, W))
    out = data
    for i in range(10):
        out = DFN(data=out, local_kernels=local_kernels, K=K, batch_size=batch_size)
    exe = out.simple_bind(ctx=mx.gpu(), data=(batch_size, C, H, W),
                          local_kernels=(batch_size, K * K, H, W))
    exe.forward(data=data_npy, local_kernels=local_kernels_npy)
    print(exe.outputs[0].asnumpy().shape)
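    # Hedged sketch: an analogous smoke test for CDNA under the same image size.
    # The symbol names ('cdna_kernels', 'cdna_mask') and the choice of num_filter are
    # illustrative assumptions, not part of the original module.
    num_filter = 10
    cdna_kernels = mx.sym.Variable('cdna_kernels')
    cdna_mask = mx.sym.Variable('cdna_mask')
    cdna_out = CDNA(data=data, kernels=cdna_kernels, mask=cdna_mask,
                    batch_size=batch_size, num_filter=num_filter, kernel_size=K)
    cdna_exe = cdna_out.simple_bind(ctx=mx.gpu(), data=(batch_size, C, H, W),
                                    cdna_kernels=(batch_size, num_filter, K, K),
                                    cdna_mask=(batch_size, num_filter, H, W))
    cdna_exe.forward(data=data_npy,
                     cdna_kernels=np.random.normal(size=(batch_size, num_filter, K, K)),
                     cdna_mask=np.random.normal(size=(batch_size, num_filter, H, W)))
    print(cdna_exe.outputs[0].asnumpy().shape)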