# LARRES / utilpack / phydnet_modules.py
# Staty's picture
# Upload 50 files
# 2b21abc verified
import torch
import torch.nn as nn
from numpy import *
from numpy.linalg import *
from scipy.special import factorial
from functools import reduce
# Names exported by a star-import of this module: only the M2K and K2M converters.
__all__ = ['M2K','K2M']
class PhyCell_Cell(nn.Module):
    """Recurrent cell that predicts the next hidden state, then corrects it with the input.

    The prediction is ``hidden + F(hidden)`` where ``F`` is a small
    conv -> GroupNorm -> 1x1 conv stack.  A sigmoid gate computed from the
    concatenated input and hidden state then blends the prediction with ``x``.

    Args:
        input_dim: channel count shared by the input and the hidden state.
        F_hidden_dim: channel count inside ``F``; it is normalised with
            ``GroupNorm(7, F_hidden_dim)``, so it has to be a multiple of 7.
        kernel_size: ``(kh, kw)`` of the first convolution of ``F``; the
            padding is half the kernel size along each axis.
        bias: truthy to give the gate convolution a bias term.
    """

    def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
        super().__init__()
        self.input_dim = input_dim
        self.F_hidden_dim = F_hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        # Predictor F; the sub-module names are part of the state_dict layout.
        predictor = nn.Sequential()
        predictor.add_module(
            'conv1',
            nn.Conv2d(input_dim, F_hidden_dim, kernel_size=self.kernel_size,
                      stride=(1, 1), padding=self.padding),
        )
        predictor.add_module('bn1', nn.GroupNorm(7, F_hidden_dim))
        predictor.add_module(
            'conv2',
            nn.Conv2d(F_hidden_dim, input_dim, kernel_size=(1, 1),
                      stride=(1, 1), padding=(0, 0)),
        )
        self.F = predictor

        # Gate convolution sees input and hidden state stacked on the channel axis.
        self.convgate = nn.Conv2d(2 * self.input_dim, self.input_dim,
                                  kernel_size=(3, 3), padding=(1, 1),
                                  bias=self.bias)

    def forward(self, x, hidden):
        """Return the next hidden state for input ``x`` and current ``hidden``.

        Both tensors are concatenated along dim 1, so they must agree on every
        other dimension.
        """
        gate = torch.sigmoid(self.convgate(torch.cat((x, hidden), dim=1)))
        predicted = hidden + self.F(hidden)  # prediction step
        # Correction step: element-wise blend of prediction and input.
        return predicted + gate * (x - predicted)
class PhyCell(nn.Module):
    """Stack of ``PhyCell_Cell`` layers that keeps its hidden states between calls.

    Args:
        input_shape: ``(height, width)`` used to size the zero hidden states.
        input_dim: channel count of the input and of every hidden state.
        F_hidden_dims: per-layer hidden width handed to each cell.
        n_layers: number of stacked cells.
        kernel_size: kernel size handed to each cell.
        device: device on which freshly initialised hidden states are placed.
    """

    def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.F_hidden_dims = F_hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H = []
        self.device = device

        self.cell_list = nn.ModuleList([
            PhyCell_Cell(input_dim=input_dim,
                         F_hidden_dim=self.F_hidden_dims[layer],
                         kernel_size=self.kernel_size)
            for layer in range(self.n_layers)
        ])

    def forward(self, input_, first_timestep=False):
        """Advance every layer by one step.

        The hidden states are reset to zeros when ``first_timestep`` is true.
        Layer 0 consumes ``input_``; each higher layer consumes the freshly
        updated state of the layer below.  The stored list of hidden states is
        returned twice (as "hidden" and as "output").
        """
        batch_size = input_.size(0)
        if first_timestep:
            self.initHidden(batch_size)  # fresh state at the start of a sequence

        layer_input = input_
        for j, cell in enumerate(self.cell_list):
            # Keep the state on the same device as the incoming data.
            self.H[j] = self.H[j].to(input_.device)
            self.H[j] = cell(layer_input, self.H[j])
            layer_input = self.H[j]
        return self.H, self.H

    def initHidden(self, batch_size):
        """Replace the hidden states with zeros, one tensor per layer."""
        height, width = self.input_shape[0], self.input_shape[1]
        self.H = [
            torch.zeros(batch_size, self.input_dim, height, width).to(self.device)
            for _ in range(self.n_layers)
        ]

    def setHidden(self, H):
        """Install an externally supplied list of hidden states."""
        self.H = H
class PhyD_ConvLSTM_Cell(nn.Module):
    """Convolutional LSTM cell that processes exactly one time step per call."""

    def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):
        """
        Args:
            input_shape: ``(height, width)`` of the input tensor.
            input_dim: number of channels of the input tensor.
            hidden_dim: number of channels of the hidden and cell states.
            kernel_size: ``(kh, kw)`` of the gate convolution; padding is half
                the kernel size along each axis.
            bias: truthy to give the gate convolution a bias term.
        """
        super().__init__()
        self.height, self.width = input_shape
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        # A single convolution produces the pre-activations of all four gates.
        self.conv = nn.Conv2d(self.input_dim + self.hidden_dim,
                              4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, x, hidden):
        """Run one LSTM step.

        Args:
            x: input tensor, concatenated with the hidden state along dim 1.
            hidden: pair ``(h_cur, c_cur)`` of current hidden and cell state.

        Returns:
            ``(h_next, c_next)``.
        """
        h_prev, c_prev = hidden
        gates = self.conv(torch.cat((x, h_prev), dim=1))
        # Channel blocks, in order: input, forget, output gate and candidate.
        in_gate, forget_gate, out_gate, candidate = gates.split(self.hidden_dim, dim=1)

        c_next = (torch.sigmoid(forget_gate) * c_prev
                  + torch.sigmoid(in_gate) * torch.tanh(candidate))
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next
class PhyD_ConvLSTM(nn.Module):
    """Stack of ``PhyD_ConvLSTM_Cell`` layers that keeps (H, C) between calls.

    Args:
        input_shape: ``(height, width)`` used to size the zero states.
        input_dim: channel count of the input fed to the first layer.
        hidden_dims: per-layer hidden channel counts; layer ``i > 0`` takes
            ``hidden_dims[i-1]`` input channels.
        n_layers: number of stacked cells.
        kernel_size: kernel size handed to every cell.
        device: device on which freshly initialised states are placed.
    """

    def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size, device):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H, self.C = [], []
        self.device = device

        cells = []
        in_channels = self.input_dim
        for layer in range(self.n_layers):
            out_channels = self.hidden_dims[layer]
            print('layer ', layer, 'input dim ', in_channels, ' hidden dim ', out_channels)
            cells.append(PhyD_ConvLSTM_Cell(input_shape=self.input_shape,
                                            input_dim=in_channels,
                                            hidden_dim=out_channels,
                                            kernel_size=self.kernel_size))
            in_channels = out_channels  # next layer reads this layer's hidden state
        self.cell_list = nn.ModuleList(cells)

    def forward(self, input_, first_timestep=False):
        """Advance every layer by one step.

        States are reset to zeros when ``first_timestep`` is true.  Returns
        ``((H, C), H)``: the stored state lists and, as output, the list of
        hidden states.
        """
        batch_size = input_.size(0)
        if first_timestep:
            self.initHidden(batch_size)  # fresh state at the start of a sequence

        for j, cell in enumerate(self.cell_list):
            # Keep the states on the same device as the incoming data.
            self.H[j] = self.H[j].to(input_.device)
            self.C[j] = self.C[j].to(input_.device)
            layer_input = input_ if j == 0 else self.H[j - 1]
            self.H[j], self.C[j] = cell(layer_input, (self.H[j], self.C[j]))
        return (self.H, self.C), self.H

    def initHidden(self, batch_size):
        """Replace hidden and cell states with zeros, one pair per layer."""
        height, width = self.input_shape[0], self.input_shape[1]
        self.H, self.C = [], []
        for layer in range(self.n_layers):
            state_shape = (batch_size, self.hidden_dims[layer], height, width)
            self.H.append(torch.zeros(*state_shape).to(self.device))
            self.C.append(torch.zeros(*state_shape).to(self.device))

    def setHidden(self, hidden):
        """Install an externally supplied ``(H, C)`` pair of state lists."""
        self.H, self.C = hidden
class dcgan_conv(nn.Module):
    """3x3 convolution followed by GroupNorm (16 groups) and LeakyReLU(0.2).

    Args:
        nin: input channels.
        nout: output channels; normalised in 16 groups, so it has to be a
            multiple of 16.
        stride: stride of the convolution.
    """

    def __init__(self, nin, nout, stride):
        super().__init__()
        conv = nn.Conv2d(nin, nout, kernel_size=(3, 3), stride=stride, padding=1)
        norm = nn.GroupNorm(16, nout)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(conv, norm, activation)

    def forward(self, input):
        """Apply the conv -> norm -> activation stack to ``input``."""
        return self.main(input)
class dcgan_upconv(nn.Module):
    """3x3 transposed convolution followed by GroupNorm (16 groups) and LeakyReLU(0.2).

    Args:
        nin: input channels.
        nout: output channels; normalised in 16 groups, so it has to be a
            multiple of 16.
        stride: stride of the transposed convolution.  With stride 2 an
            ``output_padding`` of 1 is used, otherwise 0.
    """

    def __init__(self, nin, nout, stride):
        super().__init__()
        extra = 1 if stride == 2 else 0
        upconv = nn.ConvTranspose2d(nin, nout, kernel_size=(3, 3), stride=stride,
                                    padding=1, output_padding=extra)
        norm = nn.GroupNorm(16, nout)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(upconv, norm, activation)

    def forward(self, input):
        """Apply the upconv -> norm -> activation stack to ``input``."""
        return self.main(input)
class encoder_E(nn.Module):
    """General image encoder built from three ``dcgan_conv`` stages.

    Args:
        nc: channels of the input image.
        nf: channels after the first two stages; the last stage outputs ``2*nf``.
        patch_size: 2 or 4; the last stage uses stride ``patch_size // 2``.

    For example, a ``nc x 64 x 64`` input with the defaults goes
    ``nf x 32 x 32`` -> ``nf x 32 x 32`` -> ``2*nf x 16 x 16``.
    """

    def __init__(self, nc=1, nf=32, patch_size=4):
        super().__init__()
        assert patch_size in [2, 4]
        last_stride = patch_size // 2
        self.c1 = dcgan_conv(nc, nf, stride=2)
        self.c2 = dcgan_conv(nf, nf, stride=1)
        self.c3 = dcgan_conv(nf, 2 * nf, stride=last_stride)

    def forward(self, input):
        """Return the output of the third stage."""
        return self.c3(self.c2(self.c1(input)))
class decoder_D(nn.Module):
    """General image decoder: two ``dcgan_upconv`` stages and a final transposed conv.

    Args:
        nc: channels of the produced image.
        nf: channels between the stages; the decoder expects ``2*nf`` input channels.
        patch_size: 2 or 4; the final layer uses stride ``patch_size // 2``
            (with ``output_padding`` 1 when that stride is 2).

    The final layer has no normalisation or activation.
    """

    def __init__(self, nc=1, nf=32, patch_size=4):
        super().__init__()
        assert patch_size in [2, 4]
        last_stride = patch_size // 2
        extra = 1 if last_stride == 2 else 0
        self.upc1 = dcgan_upconv(2 * nf, nf, stride=2)
        self.upc2 = dcgan_upconv(nf, nf, stride=1)
        self.upc3 = nn.ConvTranspose2d(nf, nc, kernel_size=(3, 3),
                                       stride=last_stride, padding=1,
                                       output_padding=extra)

    def forward(self, input):
        """Return the output of the final transposed convolution."""
        return self.upc3(self.upc2(self.upc1(input)))
class encoder_specific(nn.Module):
    """Branch-specific encoder: two stride-1 ``dcgan_conv`` stages (``nc`` -> ``nf`` -> ``nf``).

    Both stages use stride 1, so the spatial size is left unchanged.
    """

    def __init__(self, nc=64, nf=64):
        super().__init__()
        self.c1 = dcgan_conv(nc, nf, stride=1)
        self.c2 = dcgan_conv(nf, nf, stride=1)

    def forward(self, input):
        """Return the output of the second stage."""
        return self.c2(self.c1(input))
class decoder_specific(nn.Module):
    """Branch-specific decoder: two stride-1 ``dcgan_upconv`` stages (``nf`` -> ``nf`` -> ``nc``).

    Both stages use stride 1, so the spatial size is left unchanged.
    """

    def __init__(self, nc=64, nf=64):
        super().__init__()
        self.upc1 = dcgan_upconv(nf, nf, stride=1)
        self.upc2 = dcgan_upconv(nf, nc, stride=1)

    def forward(self, input):
        """Return the output of the second stage."""
        return self.upc2(self.upc1(input))
class PhyD_EncoderRNN(torch.nn.Module):
    """One-step encoder/recurrent/decoder wrapper around a physical and a residual cell.

    A shared encoder feeds two branch-specific encoders; their outputs drive
    ``phycell`` and ``convcell`` respectively.  The last-layer outputs of both
    cells are decoded by branch-specific decoders, summed, and passed through
    the shared decoder and a sigmoid.

    Args:
        phycell: recurrent module called as ``phycell(x, first_timestep)`` and
            returning ``(hidden, outputs)``.
        convcell: recurrent module with the same calling convention.
        in_channel: channels of the input (and reconstructed) image.
        patch_size: forwarded to the shared encoder and decoder.
    """

    def __init__(self, phycell, convcell, in_channel=1, patch_size=4):
        super().__init__()
        # Shared encoder, then one specific encoder/decoder pair per branch.
        self.encoder_E = encoder_E(nc=in_channel, patch_size=patch_size)
        self.encoder_Ep = encoder_specific()
        self.encoder_Er = encoder_specific()
        self.decoder_Dp = decoder_specific()
        self.decoder_Dr = decoder_specific()
        self.decoder_D = decoder_D(nc=in_channel, patch_size=patch_size)
        self.phycell = phycell
        self.convcell = convcell

    def forward(self, input, first_timestep=False, decoding=False):
        """Process one frame.

        Returns:
            ``(out_phys, hidden1, output_image, out_phys, out_conv)`` where
            ``hidden1`` is the hidden state reported by ``phycell``,
            ``output_image`` is the combined reconstruction and
            ``out_phys`` / ``out_conv`` are the per-branch reconstructions
            (``out_phys`` appears twice in the tuple).
        """
        features = self.encoder_E(input)

        # NOTE(review): with decoding=True the physical branch receives None;
        # the supplied phycell has to accept that -- confirm before relying on it.
        phys_in = None if decoding else self.encoder_Ep(features)
        conv_in = self.encoder_Er(features)

        hidden1, phys_states = self.phycell(phys_in, first_timestep)
        hidden2, conv_states = self.convcell(conv_in, first_timestep)

        # Only the top layer of each recurrent stack is decoded.
        decoded_Dp = self.decoder_Dp(phys_states[-1])
        decoded_Dr = self.decoder_Dr(conv_states[-1])

        # Per-branch reconstructions, kept for visualisation.
        out_phys = torch.sigmoid(self.decoder_D(decoded_Dp))
        out_conv = torch.sigmoid(self.decoder_D(decoded_Dr))

        # The two branches are merged by element-wise addition.
        output_image = torch.sigmoid(self.decoder_D(decoded_Dp + decoded_Dr))
        return out_phys, hidden1, output_image, out_phys, out_conv
def _apply_axis_left_dot(x, mats):
assert x.dim() == len(mats)+1
sizex = x.size()
k = x.dim()-1
for i in range(k):
x = tensordot(mats[k-i-1], x, dim=[1,k])
x = x.permute([k,]+list(range(k))).contiguous()
x = x.view(sizex)
return x
def _apply_axis_right_dot(x, mats):
assert x.dim() == len(mats)+1
sizex = x.size()
k = x.dim()-1
x = x.permute(list(range(1,k+1))+[0,])
for i in range(k):
x = tensordot(x, mats[i], dim=[0,0])
x = x.contiguous()
x = x.view(sizex)
return x
class _MK(nn.Module):
def __init__(self, shape):
super(_MK, self).__init__()
self._size = torch.Size(shape)
self._dim = len(shape)
M = []
invM = []
assert len(shape) > 0
j = 0
for l in shape:
M.append(zeros((l,l)))
for i in range(l):
M[-1][i] = ((arange(l)-(l-1)//2)**i)/factorial(i)
invM.append(inv(M[-1]))
self.register_buffer('_M'+str(j), torch.from_numpy(M[-1]))
self.register_buffer('_invM'+str(j), torch.from_numpy(invM[-1]))
j += 1
@property
def M(self):
return list(self._buffers['_M'+str(j)] for j in range(self.dim()))
@property
def invM(self):
return list(self._buffers['_invM'+str(j)] for j in range(self.dim()))
def size(self):
return self._size
def dim(self):
return self._dim
def _packdim(self, x):
assert x.dim() >= self.dim()
if x.dim() == self.dim():
x = x[newaxis,:]
x = x.contiguous()
x = x.view([-1,]+list(x.size()[-self.dim():]))
return x
def forward(self):
pass
class M2K(_MK):
    """
    Convert a moment matrix into a convolution kernel.

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        m2k = M2K([5,5])
        m = torch.randn(5,5,dtype=torch.float64)
        k = m2k(m)
    """

    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, m):
        """
        Apply the inverse moment matrices along the trailing kernel axes.

        m (Tensor): torch.size=[...,*self.shape]

        Returns a tensor of the same size as ``m``.
        """
        original_size = m.size()
        packed = self._packdim(m)  # collapse leading axes into one batch axis
        kernel = _apply_axis_left_dot(packed, self.invM)
        return kernel.view(original_size)
class K2M(_MK):
    """
    Convert a convolution kernel into a moment matrix.

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        k2m = K2M([5,5])
        k = torch.randn(5,5,dtype=torch.float64)
        m = k2m(k)
    """

    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, k):
        """
        Apply the moment matrices along the trailing kernel axes.

        k (Tensor): torch.size=[...,*self.shape]

        Returns a tensor of the same size as ``k``.
        """
        original_size = k.size()
        packed = self._packdim(k)  # collapse leading axes into one batch axis
        moments = _apply_axis_left_dot(packed, self.M)
        return moments.view(original_size)
def tensordot(a,b,dim):
    """
    Tensor contraction for PyTorch tensors, following ``numpy.tensordot``.

    Arguments:
        a, b (Tensor): operands of the same dtype.
        dim: either an int ``n`` (contract the last ``n`` axes of ``a`` with
            the first ``n`` axes of ``b``; ``0`` gives the outer product), or
            a pair ``(a_axes, b_axes)`` whose entries are an int or a
            list/tuple of ints.  Negative axes count from the end.

    Returns:
        Tensor whose axes are the uncontracted axes of ``a`` followed by the
        uncontracted axes of ``b``.

    Raises:
        AssertionError: if the contracted axes of ``a`` and ``b`` do not hold
            the same number of elements.
    """
    if isinstance(dim, int):
        adims = list(range(a.dim() - dim, a.dim()))
        bdims = list(range(dim))
    else:
        adims, bdims = dim[0], dim[1]
        # list(...) so tuple axis specs can be concatenated with lists below.
        adims = [adims] if isinstance(adims, int) else list(adims)
        bdims = [bdims] if isinstance(bdims, int) else list(bdims)
        # Normalise negative axes so the "kept axes" computation sees them.
        adims = [d + a.dim() if d < 0 else d for d in adims]
        bdims = [d + b.dim() if d < 0 else d for d in bdims]

    akeep = [d for d in range(a.dim()) if d not in adims]
    bkeep = [d for d in range(b.dim()) if d not in bdims]

    # Bring the contracted axes to the end of `a` and to the front of `b`.
    a = a.permute(akeep + adims).contiguous()
    b = b.permute(bdims + bkeep).contiguous()

    # Split by count of kept axes rather than with negative slices, so that
    # contracting zero axes (outer product) is handled correctly.
    sizea0, sizea1 = a.size()[:len(akeep)], a.size()[len(akeep):]
    sizeb0, sizeb1 = b.size()[:len(bdims)], b.size()[len(bdims):]

    mul = lambda x, y: x * y
    N = reduce(mul, sizea1, 1)
    assert reduce(mul, sizeb0, 1) == N, "contracted axes of a and b differ in size"

    # The contraction itself is a single matrix product.
    c = a.view([-1, N]) @ b.view([N, -1])
    return c.view(sizea0 + sizeb1)