|
|
import torch |
|
|
import torch.nn as nn |
|
|
from numpy import * |
|
|
from numpy.linalg import * |
|
|
from scipy.special import factorial |
|
|
from functools import reduce |
|
|
|
|
|
__all__ = ['M2K','K2M'] |
|
|
|
|
|
|
|
|
class PhyCell_Cell(nn.Module):
    """One PhyCell step: physics predictor F plus a convolutional correction gate.

    Args:
        input_dim: number of channels of the input / hidden state.
        F_hidden_dim: hidden channel count inside the physics branch F.
        kernel_size: (kh, kw) of the first convolution in F.
        bias: whether the gate convolution uses a bias term.
    """

    def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
        super(PhyCell_Cell, self).__init__()
        self.input_dim = input_dim
        self.F_hidden_dim = F_hidden_dim
        self.kernel_size = kernel_size
        # "same" padding for odd kernel sizes.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # Physics branch F: conv -> GroupNorm -> 1x1 conv back to input_dim.
        # NOTE(review): GroupNorm(7, ...) requires F_hidden_dim divisible by 7 — confirm callers.
        self.F = nn.Sequential()
        self.F.add_module('conv1',
                          nn.Conv2d(in_channels=input_dim, out_channels=F_hidden_dim,
                                    kernel_size=self.kernel_size, stride=(1, 1),
                                    padding=self.padding))
        self.F.add_module('bn1', nn.GroupNorm(7, F_hidden_dim))
        self.F.add_module('conv2',
                          nn.Conv2d(in_channels=F_hidden_dim, out_channels=input_dim,
                                    kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)))

        # Correction gate K computed from the concatenation [x, hidden].
        self.convgate = nn.Conv2d(in_channels=self.input_dim + self.input_dim,
                                  out_channels=self.input_dim,
                                  kernel_size=(3, 3),
                                  padding=(1, 1), bias=self.bias)

    def forward(self, x, hidden):
        """Predict h~ = h + F(h), then correct toward x with gate K. Returns next hidden."""
        gate_input = torch.cat([x, hidden], dim=1)
        K = torch.sigmoid(self.convgate(gate_input))
        hidden_tilde = hidden + self.F(hidden)           # prediction step
        return hidden_tilde + K * (x - hidden_tilde)     # correction step
|
|
|
|
|
|
|
|
class PhyCell(nn.Module):
    """Stack of PhyCell_Cell layers with hidden states persisted across time steps."""

    def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
        super(PhyCell, self).__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.F_hidden_dims = F_hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H = []          # per-layer hidden states, populated by initHidden()
        self.device = device

        self.cell_list = nn.ModuleList(
            [PhyCell_Cell(input_dim=input_dim,
                          F_hidden_dim=self.F_hidden_dims[i],
                          kernel_size=self.kernel_size)
             for i in range(self.n_layers)])

    def forward(self, input_, first_timestep=False):
        """Advance every layer one time step.

        Layer 0 consumes input_; each deeper layer consumes the layer below.
        Returns (H, H) — the list of per-layer hidden states, twice.
        """
        # NOTE(review): PhyD_EncoderRNN passes input_=None when decoding, which
        # would fail on the .data access below — confirm intended usage upstream.
        batch_size = input_.data.size()[0]
        if (first_timestep):
            self.initHidden(batch_size)  # reset states at the start of a sequence
        for j, cell in enumerate(self.cell_list):
            self.H[j] = self.H[j].to(input_.device)
            if j == 0:
                self.H[j] = cell(input_, self.H[j])
            else:
                self.H[j] = cell(self.H[j - 1], self.H[j])
        return self.H, self.H

    def initHidden(self, batch_size):
        """Zero-initialize one hidden tensor per layer on self.device."""
        self.H = [torch.zeros(batch_size, self.input_dim,
                              self.input_shape[0], self.input_shape[1]).to(self.device)
                  for _ in range(self.n_layers)]

    def setHidden(self, H):
        """Overwrite the stored per-layer hidden states."""
        self.H = H
|
|
|
|
|
|
|
|
class PhyD_ConvLSTM_Cell(nn.Module):
    """A single ConvLSTM cell: one convolution yields all four gate pre-activations."""

    def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):
        """
        input_shape: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super(PhyD_ConvLSTM_Cell, self).__init__()

        self.height, self.width = input_shape
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # "same" padding for odd kernel sizes.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution computes i, f, o, g pre-activations in a single pass.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding, bias=self.bias)

    def forward(self, x, hidden):
        """One ConvLSTM step. hidden is the (h, c) pair; returns (h_next, c_next)."""
        h_prev, c_prev = hidden

        gates = self.conv(torch.cat([x, h_prev], dim=1))
        in_pre, forget_pre, out_pre, cell_pre = torch.split(gates, self.hidden_dim, dim=1)

        gate_i = torch.sigmoid(in_pre)
        gate_f = torch.sigmoid(forget_pre)
        gate_o = torch.sigmoid(out_pre)
        gate_g = torch.tanh(cell_pre)

        c_next = gate_f * c_prev + gate_i * gate_g
        h_next = gate_o * torch.tanh(c_next)
        return h_next, c_next
|
|
|
|
|
|
|
|
class PhyD_ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells with (H, C) states persisted across time steps."""

    def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size, device):
        super(PhyD_ConvLSTM, self).__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H, self.C = [], []   # per-layer states, populated by initHidden()
        self.device = device

        cells = []
        for i in range(self.n_layers):
            # Layer 0 reads the raw input; deeper layers read the previous layer's output.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dims[i-1]
            print('layer ', i, 'input dim ', cur_input_dim, ' hidden dim ', self.hidden_dims[i])
            cells.append(PhyD_ConvLSTM_Cell(input_shape=self.input_shape,
                                            input_dim=cur_input_dim,
                                            hidden_dim=self.hidden_dims[i],
                                            kernel_size=self.kernel_size))
        self.cell_list = nn.ModuleList(cells)

    def forward(self, input_, first_timestep=False):
        """Advance all layers one time step; returns ((H, C), H)."""
        batch_size = input_.data.size()[0]
        if (first_timestep):
            self.initHidden(batch_size)  # reset states at the start of a sequence
        for j, cell in enumerate(self.cell_list):
            self.H[j] = self.H[j].to(input_.device)
            self.C[j] = self.C[j].to(input_.device)
            if j == 0:
                self.H[j], self.C[j] = cell(input_, (self.H[j], self.C[j]))
            else:
                self.H[j], self.C[j] = cell(self.H[j-1], (self.H[j], self.C[j]))
        return (self.H, self.C), self.H

    def initHidden(self, batch_size):
        """Zero-initialize the (H, C) pair for every layer on self.device."""
        self.H, self.C = [], []
        for i in range(self.n_layers):
            state_shape = (batch_size, self.hidden_dims[i],
                           self.input_shape[0], self.input_shape[1])
            self.H.append(torch.zeros(*state_shape).to(self.device))
            self.C.append(torch.zeros(*state_shape).to(self.device))

    def setHidden(self, hidden):
        """Overwrite the stored (H, C) states."""
        H, C = hidden
        self.H, self.C = H, C
|
|
|
|
|
|
|
|
class dcgan_conv(nn.Module):
    """3x3 conv -> GroupNorm(16 groups) -> LeakyReLU(0.2) downsampling block."""

    def __init__(self, nin, nout, stride):
        super(dcgan_conv, self).__init__()
        # NOTE(review): GroupNorm(16, ...) requires nout divisible by 16 — confirm callers.
        layers = [
            nn.Conv2d(nin, nout, kernel_size=(3, 3), stride=stride, padding=1),
            nn.GroupNorm(16, nout),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
|
|
|
|
|
|
|
|
class dcgan_upconv(nn.Module):
    """3x3 transposed conv -> GroupNorm(16 groups) -> LeakyReLU(0.2) upsampling block."""

    def __init__(self, nin, nout, stride):
        super(dcgan_upconv, self).__init__()
        # output_padding=1 is required to exactly invert a stride-2 convolution.
        output_padding = 1 if stride == 2 else 0
        layers = [
            nn.ConvTranspose2d(nin, nout, kernel_size=(3, 3),
                               stride=stride, padding=1, output_padding=output_padding),
            nn.GroupNorm(16, nout),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
|
|
|
|
|
|
|
|
class encoder_E(nn.Module):
    """Image encoder: downsample by patch_size (2 or 4) into 2*nf feature channels."""

    def __init__(self, nc=1, nf=32, patch_size=4):
        super(encoder_E, self).__init__()
        assert patch_size in [2, 4]
        # Total downsampling factor = 2 * stride_2 = patch_size.
        stride_2 = patch_size // 2

        self.c1 = dcgan_conv(nc, nf, stride=2)
        self.c2 = dcgan_conv(nf, nf, stride=1)
        self.c3 = dcgan_conv(nf, 2 * nf, stride=stride_2)

    def forward(self, input):
        return self.c3(self.c2(self.c1(input)))
|
|
|
|
|
|
|
|
class decoder_D(nn.Module):
    """Image decoder: upsample 2*nf feature maps back to nc channels by patch_size."""

    def __init__(self, nc=1, nf=32, patch_size=4):
        super(decoder_D, self).__init__()
        assert patch_size in [2, 4]
        stride_2 = patch_size // 2
        # output_padding=1 is required to exactly invert a stride-2 convolution.
        output_padding = 1 if stride_2 == 2 else 0
        self.upc1 = dcgan_upconv(2 * nf, nf, stride=2)
        self.upc2 = dcgan_upconv(nf, nf, stride=1)
        self.upc3 = nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=(3, 3),
                                       stride=stride_2, padding=1,
                                       output_padding=output_padding)

    def forward(self, input):
        return self.upc3(self.upc2(self.upc1(input)))
|
|
|
|
|
|
|
|
class encoder_specific(nn.Module):
    """Branch-specific feature adapter: two stride-1 dcgan_conv blocks."""

    def __init__(self, nc=64, nf=64):
        super(encoder_specific, self).__init__()
        self.c1 = dcgan_conv(nc, nf, stride=1)
        self.c2 = dcgan_conv(nf, nf, stride=1)

    def forward(self, input):
        return self.c2(self.c1(input))
|
|
|
|
|
|
|
|
class decoder_specific(nn.Module):
    """Branch-specific feature adapter: two stride-1 dcgan_upconv blocks."""

    def __init__(self, nc=64, nf=64):
        super(decoder_specific, self).__init__()
        self.upc1 = dcgan_upconv(nf, nf, stride=1)
        self.upc2 = dcgan_upconv(nf, nc, stride=1)

    def forward(self, input):
        return self.upc2(self.upc1(input))
|
|
|
|
|
|
|
|
class PhyD_EncoderRNN(torch.nn.Module):
    """Encoder-decoder wrapper combining a physics (phycell) and a ConvLSTM branch."""

    def __init__(self, phycell, convcell, in_channel=1, patch_size=4):
        super(PhyD_EncoderRNN, self).__init__()
        # Shared image encoder/decoder plus branch-specific feature adapters.
        self.encoder_E = encoder_E(nc=in_channel, patch_size=patch_size)
        self.encoder_Ep = encoder_specific()
        self.encoder_Er = encoder_specific()
        self.decoder_Dp = decoder_specific()
        self.decoder_Dr = decoder_specific()
        self.decoder_D = decoder_D(nc=in_channel, patch_size=patch_size)

        self.phycell = phycell
        self.convcell = convcell

    def forward(self, input, first_timestep=False, decoding=False):
        """One recurrent step; returns (out_phys, hidden1, output_image, out_phys, out_conv)."""
        features = self.encoder_E(input)

        # NOTE(review): in decoding mode the physics branch receives None; the
        # phycell is then expected to evolve from its stored state — confirm.
        input_phys = None if decoding else self.encoder_Ep(features)
        input_conv = self.encoder_Er(features)

        hidden1, output1 = self.phycell(input_phys, first_timestep)
        hidden2, output2 = self.convcell(input_conv, first_timestep)

        decoded_Dp = self.decoder_Dp(output1[-1])
        decoded_Dr = self.decoder_Dr(output2[-1])

        # Per-branch reconstructions (used for branch-specific losses).
        out_phys = torch.sigmoid(self.decoder_D(decoded_Dp))
        out_conv = torch.sigmoid(self.decoder_D(decoded_Dr))

        # Final frame: decode the sum of both branches' features.
        output_image = torch.sigmoid(self.decoder_D(decoded_Dp + decoded_Dr))
        # NOTE(review): out_phys is returned twice, matching the upstream interface.
        return out_phys, hidden1, output_image, out_phys, out_conv
|
|
|
|
|
|
|
|
def _apply_axis_left_dot(x, mats):
    """Left-multiply x by mats[i] along axis i+1 of x (axis 0 is the batch axis)."""
    assert x.dim() == len(mats)+1
    sizex = x.size()
    k = x.dim()-1
    # Each tensordot contracts the current last axis of x with the column axis
    # of one matrix and prepends the result axis, so after k steps every data
    # axis has been transformed and the batch axis has drifted to the end.
    for i in range(k):
        x = tensordot(mats[k-i-1], x, dim=[1,k])
    # Rotate the batch axis back to the front, then restore the original shape.
    x = x.permute([k,]+list(range(k))).contiguous()
    x = x.view(sizex)
    return x
|
|
|
|
|
def _apply_axis_right_dot(x, mats):
    """Right-multiply x by mats[i] along axis i+1 of x (axis 0 is the batch axis)."""
    assert x.dim() == len(mats)+1
    sizex = x.size()
    k = x.dim()-1
    # Move the batch axis to the end so the data axes can be consumed in order.
    x = x.permute(list(range(1,k+1))+[0,])
    # Each tensordot contracts the current first axis of x with the row axis of
    # one matrix and appends the result axis; after k steps the layout is
    # (batch, transformed axes...).
    for i in range(k):
        x = tensordot(x, mats[i], dim=[0,0])
    x = x.contiguous()
    x = x.view(sizex)
    return x
|
|
|
|
|
class _MK(nn.Module):
    """Base module holding a per-axis moment matrix M and its inverse.

    For an axis of length l, row i of M samples the centered monomial
    (x - c)**i / i! at x = 0..l-1 with center c = (l-1)//2, so M maps a
    1-D kernel to its moments along that axis. Matrices are registered as
    buffers '_M{j}' / '_invM{j}', one pair per axis of `shape`.
    """

    def __init__(self, shape):
        super(_MK, self).__init__()
        self._size = torch.Size(shape)
        self._dim = len(shape)
        assert len(shape) > 0
        # enumerate replaces the original hand-maintained j counter; the
        # accumulator lists were dead after __init__ and are dropped.
        for j, l in enumerate(shape):
            M = zeros((l, l))
            for i in range(l):
                # Row i: centered monomial x**i / i! sampled on the integer grid.
                M[i] = ((arange(l)-(l-1)//2)**i)/factorial(i)
            self.register_buffer('_M'+str(j), torch.from_numpy(M))
            self.register_buffer('_invM'+str(j), torch.from_numpy(inv(M)))

    @property
    def M(self):
        """Per-axis moment matrices, in axis order."""
        return [self._buffers['_M'+str(j)] for j in range(self.dim())]

    @property
    def invM(self):
        """Per-axis inverse moment matrices, in axis order."""
        return [self._buffers['_invM'+str(j)] for j in range(self.dim())]

    def size(self):
        """Kernel shape this module was built for."""
        return self._size

    def dim(self):
        """Number of kernel axes."""
        return self._dim

    def _packdim(self, x):
        """Collapse all leading axes of x into a single batch axis."""
        assert x.dim() >= self.dim()
        if x.dim() == self.dim():
            # numpy's newaxis is None, so this also works on torch tensors.
            x = x[newaxis,:]
        x = x.contiguous()
        x = x.view([-1,]+list(x.size()[-self.dim():]))
        return x

    def forward(self):
        # Base class is not directly callable; subclasses implement forward.
        pass
|
|
|
|
|
|
|
|
class M2K(_MK):
    """
    convert moment matrix to convolution kernel

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        m2k = M2K([5,5])
        m = torch.randn(5,5,dtype=torch.float64)
        k = m2k(m)
    """

    def __init__(self, shape):
        super(M2K, self).__init__(shape)

    def forward(self, m):
        """
        m (Tensor): torch.size=[...,*self.shape]
        """
        # Apply the inverse moment matrix along every kernel axis, keeping shape.
        original_size = m.size()
        packed = self._packdim(m)
        transformed = _apply_axis_left_dot(packed, self.invM)
        return transformed.view(original_size)
|
|
|
|
|
|
|
|
class K2M(_MK):
    """
    convert convolution kernel to moment matrix

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        k2m = K2M([5,5])
        k = torch.randn(5,5,dtype=torch.float64)
        m = k2m(k)
    """

    def __init__(self, shape):
        super(K2M, self).__init__(shape)

    def forward(self, k):
        """
        k (Tensor): torch.size=[...,*self.shape]
        """
        # Apply the moment matrix along every kernel axis, keeping shape.
        original_size = k.size()
        packed = self._packdim(k)
        transformed = _apply_axis_left_dot(packed, self.M)
        return transformed.view(original_size)
|
|
|
|
|
|
|
|
def tensordot(a, b, dim):
    """
    tensordot in PyTorch, see numpy.tensordot?

    dim may be an int (contract the last `dim` axes of `a` with the first
    `dim` axes of `b`) or a pair (adims, bdims) of axis indices / index
    lists to contract, following numpy.tensordot semantics.
    """
    # Named helper replaces the original `l = lambda x, y: x * y`
    # (PEP 8 E731 lambda assignment with the ambiguous name `l`, E741).
    def _prod(sizes):
        return reduce(lambda x, y: x * y, sizes, 1)

    if isinstance(dim, int):
        a = a.contiguous()
        b = b.contiguous()
        sizea = a.size()
        sizeb = b.size()
        sizea0 = sizea[:-dim]
        sizea1 = sizea[-dim:]
        sizeb0 = sizeb[:dim]
        sizeb1 = sizeb[dim:]
        N = _prod(sizea1)
        assert _prod(sizeb0) == N
    else:
        adims, bdims = dim
        adims = [adims, ] if isinstance(adims, int) else adims
        bdims = [bdims, ] if isinstance(bdims, int) else bdims
        # Move the contracted axes of `a` to the end and those of `b` to
        # the front so the contraction collapses into one matmul.
        adims_ = sorted(set(range(a.dim())).difference(adims))
        perma = adims_ + adims
        bdims_ = sorted(set(range(b.dim())).difference(bdims))
        permb = bdims + bdims_
        a = a.permute(*perma).contiguous()
        b = b.permute(*permb).contiguous()

        sizea = a.size()
        sizeb = b.size()
        sizea0 = sizea[:-len(adims)]
        sizea1 = sizea[-len(adims):]
        sizeb0 = sizeb[:len(bdims)]
        sizeb1 = sizeb[len(bdims):]
        N = _prod(sizea1)
        assert _prod(sizeb0) == N
    # Collapse both operands around the contracted size N and multiply.
    a = a.view([-1, N])
    b = b.view([N, -1])
    c = a @ b
    return c.view(sizea0 + sizeb1)
|
|
|