# pltnhan07's picture
# Add files using upload-large-folder tool
# 3d7e366 verified
import numpy as np
import torch
import torch.nn.functional as F
from codebase import utils as ut
from torch import autograd, nn, optim
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear
import numpy as np
import torch
import torch.nn.functional as F
from codebase import utils as ut
from torch import autograd, nn, optim
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear
# Gaussian Encoder
class Encoder(nn.Module):
    """MLP Gaussian encoder over flattened ``channel x 96 x 96`` inputs.

    Two encoding paths are provided:

    * :meth:`encode` runs the flattened input through ``self.net`` to get
      ``2 * z_dim`` values.
    * :meth:`conditional_encode` concatenates a ``y_dim``-wide label vector
      after the first hidden layer (``fc1`` -> ``fc2`` -> ``fc3`` -> ``fc4``).

    Both hand the ``2 * z_dim`` values to ``ut.gaussian_parameters`` and
    return its ``(m, v)`` pair (presumably mean and variance).

    ``LReLU``, ``fc_mu`` and ``fc_var`` are created but not used by the
    methods of this class. They are kept so existing checkpoints still load.
    """

    def __init__(self, z_dim, channel=4, y_dim=4):
        super().__init__()
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.channel = channel
        in_dim = self.channel * 96 * 96
        # Layers of the label-conditioned path (conditional_encode).
        self.fc1 = nn.Linear(in_dim, 300)
        self.fc2 = nn.Linear(300 + y_dim, 300)
        self.fc3 = nn.Linear(300, 300)
        self.fc4 = nn.Linear(300, 2 * z_dim)
        self.LReLU = nn.LeakyReLU(0.2, inplace=True)
        # Unconditioned path (encode).
        self.net = nn.Sequential(
            nn.Linear(in_dim, 900),
            nn.ELU(),
            nn.Linear(900, 300),
            nn.ELU(),
            nn.Linear(300, 2 * z_dim),
        )
        self.fc_mu = nn.Linear(300, z_dim)
        self.fc_var = nn.Linear(300, z_dim)

    def conditional_encode(self, x, l):
        """Encode ``x`` conditioned on the label tensor ``l``.

        Args:
            x: tensor viewable as ``(batch, channel * 96 * 96)``.
            l: tensor viewable as ``(batch, y_dim)``.

        Returns:
            ``(m, v)`` as produced by ``ut.gaussian_parameters`` on the
            ``2 * z_dim`` output of ``fc4``.
        """
        x = x.view(-1, self.channel * 96 * 96)
        x = F.elu(self.fc1(x))
        # The label width must match the y_dim used to size fc2. It was
        # previously hard-coded to 4, which broke any other y_dim.
        l = l.view(-1, self.y_dim)
        x = F.elu(self.fc2(torch.cat([x, l], dim=1)))
        x = F.elu(self.fc3(x))
        x = self.fc4(x)
        m, v = ut.gaussian_parameters(x, dim=1)
        return m, v

    def encode(self, x, y=None):
        """Encode ``x`` (optionally concatenated with ``y`` on dim 1).

        NOTE(review): the concatenation happens before the view to
        ``(-1, channel * 96 * 96)``. A non-None ``y`` therefore changes the
        element count and corrupts the batch layout — confirm callers only
        pass ``y=None``.
        """
        xy = x if y is None else torch.cat((x, y), dim=1)
        xy = xy.view(-1, self.channel * 96 * 96)
        h = self.net(xy)
        m, v = ut.gaussian_parameters(h, dim=1)
        return m, v
class Decoder(nn.Module):
    """Three-layer ELU MLP mapping a latent (optionally label-augmented)
    vector to a flat ``4 * 96 * 96`` output."""

    def __init__(self, z_dim, y_dim=0):
        super().__init__()
        self.z_dim = z_dim
        self.y_dim = y_dim
        hidden = 300
        layers = [
            nn.Linear(z_dim + y_dim, hidden),
            nn.ELU(),
            nn.Linear(hidden, hidden),
            nn.ELU(),
            nn.Linear(hidden, 4 * 96 * 96),
        ]
        self.net = nn.Sequential(*layers)

    def decode(self, z, y=None):
        """Decode ``z``; a non-None ``y`` is first appended along dim 1."""
        if y is not None:
            z = torch.cat((z, y), dim=1)
        return self.net(z)
class Decoder_DAG(nn.Module):
    """Decoder built from per-concept MLP branches.

    The latent is treated as ``concept`` chunks of ``z1_dim`` values. Each of
    ``net1``..``net4`` maps one chunk (``z1_dim + y_dim`` inputs) to a flat
    ``channel * 96 * 96`` output, and the separate-decoding methods average
    the branch outputs. ``net6`` decodes the whole ``z_dim`` latent in one
    MLP (:meth:`decode`). Only four branches exist, so at most four concepts
    can be decoded separately.
    """

    def __init__(self, z_dim, concept, z1_dim, channel=4, y_dim=0):
        super().__init__()
        self.z_dim = z_dim
        self.z1_dim = z1_dim
        self.concept = concept
        self.y_dim = y_dim
        self.channel = channel
        out_dim = self.channel * 96 * 96
        self.elu = nn.ELU()

        def branch():
            # (z1_dim + y_dim) -> 300 -> 300 -> 1024 -> out_dim, ELU between.
            return nn.Sequential(
                nn.Linear(z1_dim + y_dim, 300),
                nn.ELU(),
                nn.Linear(300, 300),
                nn.ELU(),
                nn.Linear(300, 1024),
                nn.ELU(),
                nn.Linear(1024, out_dim),
            )

        # Created in this order so that parameter registration (and RNG use
        # during initialisation) matches the original layout.
        self.net1 = branch()
        self.net2 = branch()
        self.net3 = branch()
        self.net4 = branch()
        self.net5 = nn.Sequential(
            nn.ELU(),
            nn.Linear(1024, out_dim),
        )
        self.net6 = nn.Sequential(
            nn.Linear(z_dim, 300),
            nn.ELU(),
            nn.Linear(300, 300),
            nn.ELU(),
            nn.Linear(300, 1024),
            nn.ELU(),
            nn.Linear(1024, 1024),
            nn.ELU(),
            nn.Linear(1024, out_dim),
        )

    def _split_concepts(self, zy):
        """Split ``zy`` into exactly ``self.concept`` per-concept inputs."""
        if self.z1_dim == 1:
            zy = zy.reshape(zy.size()[0], zy.size()[1], 1)
            return [zy[:, k] for k in range(self.concept)]
        chunks = torch.split(zy, self.z_dim // self.concept, dim=1)
        if len(chunks) != self.concept:
            raise ValueError(
                f"expected {self.concept} chunks of width "
                f"{self.z_dim // self.concept}, got {len(chunks)}")
        return list(chunks)

    def decode_condition(self, z, u):
        """Decode four latent chunks, each conditioned on its own label column.

        ``z`` is viewed as ``(batch, 16)`` and split into four chunks of
        ``z_dim // 4`` columns. Column ``k`` of ``u`` is appended to chunk
        ``k`` before it goes through ``net{k+1}``. The branch input width is
        therefore ``z_dim // 4 + 1``, which must equal ``z1_dim + y_dim``.

        Returns:
            The element-wise mean of the four branch outputs.
        """
        z = z.view(-1, 4 * 4)
        z1, z2, z3, z4 = torch.split(z, self.z_dim // 4, dim=1)
        nets = (self.net1, self.net2, self.net3, self.net4)
        outs = []
        for k, (chunk, net) in enumerate(zip((z1, z2, z3, z4), nets)):
            # Bug fix: the fourth branch previously reused u[:, 2].
            label = u[:, k].reshape(-1, 1)
            outs.append(net(torch.cat((chunk, label), dim=1)))
        return (outs[0] + outs[1] + outs[2] + outs[3]) / 4

    def decode_mix(self, z):
        """Sum ``z`` over its dim-1 axis (after swapping dims 1 and 2) and
        decode the result with ``net1``."""
        z = z.permute(0, 2, 1)
        z = torch.sum(z, dim=2, out=None)
        z = z.contiguous()
        h = self.net1(z)
        return h

    def decode_union(self, z, u, y=None):
        """Decode four concept chunks and merge their average through ``net5``.

        NOTE(review): ``net5`` starts with ``Linear(1024, ...)``, but the
        averaged branch output has ``channel * 96 * 96`` features. As written
        this call fails with a shape error — confirm the intended design.
        """
        z = z.view(-1, self.concept * self.z1_dim)
        zy = z if y is None else torch.cat((z, y), dim=1)
        if self.z1_dim == 1:
            zy = zy.reshape(zy.size()[0], zy.size()[1], 1)
            zy1, zy2, zy3, zy4 = zy[:, 0], zy[:, 1], zy[:, 2], zy[:, 3]
        else:
            zy1, zy2, zy3, zy4 = torch.split(zy, self.z_dim // self.concept, dim=1)
        rx1 = self.net1(zy1)
        rx2 = self.net2(zy2)
        rx3 = self.net3(zy3)
        rx4 = self.net4(zy4)
        h = self.net5((rx1 + rx2 + rx3 + rx4) / 4)
        return h, h, h, h, h

    def decode(self, z, u, y=None):
        """Decode the full latent with ``net6``; ``u`` and ``y`` are unused.

        The output is returned five times to match the other decode methods.
        """
        z = z.view(-1, self.concept * self.z1_dim)
        h = self.net6(z)
        return h, h, h, h, h

    def decode_sep(self, z, u, y=None):
        """Decode each concept chunk with its own branch and average them.

        Args:
            z: latent viewable as ``(batch, concept * z1_dim)``.
            u: unused.
            y: optional tensor appended to ``z`` along dim 1.

        Returns:
            The averaged branch output, repeated five times.

        Raises:
            ValueError: if ``concept`` is not 3 or 4, or the latent cannot be
                split into exactly ``concept`` chunks.
        """
        if self.concept not in (3, 4):
            raise ValueError(
                f"decode_sep supports concept in (3, 4), got {self.concept}")
        z = z.view(-1, self.concept * self.z1_dim)
        zy = z if y is None else torch.cat((z, y), dim=1)
        parts = self._split_concepts(zy)
        nets = (self.net1, self.net2, self.net3, self.net4)
        h = None
        for net, part in zip(nets, parts):
            out = net(part)
            h = out if h is None else h + out
        h = h / self.concept
        return h, h, h, h, h

    def decode_cat(self, z, u, y=None):
        """Decode four single-column slices and merge their concatenation.

        NOTE(review): ``z`` is viewed as ``(batch, 16)`` and split into width-1
        columns, which yields 16 tensors, not 4. ``net5`` also expects 1024
        inputs. As written this method cannot run — confirm before use.
        """
        z = z.view(-1, 4 * 4)
        zy = z if y is None else torch.cat((z, y), dim=1)
        zy1, zy2, zy3, zy4 = torch.split(zy, 1, dim=1)
        rx1 = self.net1(zy1)
        rx2 = self.net2(zy2)
        rx3 = self.net3(zy3)
        rx4 = self.net4(zy4)
        h = self.net5(torch.cat((rx1, rx2, rx3, rx4), dim=1))
        return h
class F_SCM(nn.Module):
    """Learned structural update for four concept slots.

    ``self.net`` maps a flattened (masked) latent of
    ``latent_dim * f_dim`` values per sample to ``f_dim`` values for one slot.
    """

    def __init__(self, latent_dim=4, f_dim=4):
        super().__init__()
        self.latent_dim = latent_dim
        self.f_dim = f_dim
        self.net = nn.Sequential(
            nn.Linear(self.latent_dim * self.f_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, f_dim),
        )

    def forward(self, z, z_int, mask, I=None):
        """Recompute the masked slots of ``z``.

        Args:
            z: latents indexed as ``z[:, i]`` for slots ``i`` in ``range(4)``;
                presumably ``(batch, latent_dim, f_dim)`` — TODO confirm.
            z_int: intervened latents with the same indexing as ``z``.
            mask: read as ``mask[:, i]``. When that column contains a 1, slot
                ``i`` is replaced by ``self.net`` applied to ``z * mask[:, i]``
                plus standard-normal noise. Otherwise it is copied from ``z``.
            I: optional per-sample intervention flags read as ``I[j][0][i]``.

        Returns:
            A new tensor with the same size as ``z``, on ``z``'s device.
        """
        # Allocate on z's device. The original referenced a module-level
        # ``device`` name that is never defined, which raised NameError.
        device = z.device
        z_masked = torch.empty(z.size(), device=device)
        for i in range(4):
            if I is not None:
                for j in range(z.shape[0]):
                    if I[j][0][i] == 0:
                        z_masked[j, i] = z[j, i]
                    else:
                        z_masked[j, i] = z_int[j, i]
            # NOTE(review): this branch always runs and overwrites the
            # per-sample intervention values written above, so ``I`` has no
            # effect on the result. Confirm the intended semantics.
            if 1 in mask[:, i]:
                eps = torch.normal(
                    mean=torch.zeros(self.f_dim), std=torch.ones(self.f_dim)
                ).to(device)
                flat = (z * mask[:, i]).reshape(-1, self.latent_dim * self.f_dim)
                z_masked[:, i] = self.net(flat) + eps
            else:
                z_masked[:, i] = z[:, i]
        return z_masked
class MaskLayer(nn.Module):
    """Per-concept and whole-latent ELU MLPs used for masking/mixing latents.

    ``net1``..``net4`` each map one ``z1_dim`` concept chunk back to
    ``z1_dim`` values. ``net`` maps the full ``z_dim`` latent to ``z_dim``.
    """

    def __init__(self, z_dim, concept=4, z1_dim=4):
        super().__init__()
        self.z_dim = z_dim
        self.z1_dim = z1_dim
        self.concept = concept
        self.elu = nn.ELU()

        def concept_net():
            return nn.Sequential(
                nn.Linear(z1_dim, 32),
                nn.ELU(),
                nn.Linear(32, z1_dim),
            )

        self.net1 = concept_net()
        self.net2 = concept_net()
        self.net3 = concept_net()
        self.net4 = concept_net()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 32),
            nn.ELU(),
            nn.Linear(32, z_dim),
        )

    def masked(self, z):
        """Apply ``net`` to ``z`` viewed as ``(batch, z_dim)``."""
        z = z.view(-1, self.z_dim)
        return self.net(z)

    def masked_sep(self, z):
        """Same computation as :meth:`masked`."""
        return self.masked(z)

    def mix(self, z):
        """Apply ``net{k}`` to the k-th concept chunk of ``z`` and concatenate.

        ``z`` is viewed as ``(batch, concept * z1_dim)``. With ``z1_dim == 1``
        each concept is one column. Otherwise the columns are split into
        chunks of ``z_dim // concept``.

        Raises:
            ValueError: if ``concept`` is not 3 or 4 (only four per-concept
                nets exist), or the split does not give ``concept`` chunks.
        """
        if self.concept not in (3, 4):
            raise ValueError(
                f"mix supports concept in (3, 4), got {self.concept}")
        zy = z.view(-1, self.concept * self.z1_dim)
        if self.z1_dim == 1:
            zy = zy.reshape(zy.size()[0], zy.size()[1], 1)
            chunks = [zy[:, k] for k in range(self.concept)]
        else:
            chunks = torch.split(zy, self.z_dim // self.concept, dim=1)
            if len(chunks) != self.concept:
                raise ValueError(
                    f"expected {self.concept} chunks of width "
                    f"{self.z_dim // self.concept}, got {len(chunks)}")
        nets = (self.net1, self.net2, self.net3, self.net4)
        return torch.cat([net(chunk) for net, chunk in zip(nets, chunks)], dim=1)
class MaskLayer1(nn.Module):
    """:class:`MaskLayer` variant with an extra ``net_g`` head.

    ``net1``..``net4`` map one ``z1_dim`` concept chunk to ``z1_dim`` values.
    ``net`` maps the full ``z_dim`` latent to ``z_dim``. ``net_g`` maps the
    full ``z_dim`` latent to ``z1_dim`` (see :meth:`g`).
    """

    def __init__(self, z_dim, concept=4, z1_dim=4):
        super().__init__()
        self.z_dim = z_dim
        self.z1_dim = z1_dim
        self.concept = concept
        self.elu = nn.ELU()

        def concept_net():
            return nn.Sequential(
                nn.Linear(z1_dim, 32),
                nn.ELU(),
                nn.Linear(32, z1_dim),
            )

        self.net1 = concept_net()
        self.net2 = concept_net()
        self.net3 = concept_net()
        self.net4 = concept_net()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 32),
            nn.ELU(),
            nn.Linear(32, z_dim),
        )
        self.net_g = nn.Sequential(
            nn.Linear(z_dim, 32),
            nn.ELU(),
            nn.Linear(32, z1_dim),
        )

    def masked(self, z):
        """Apply ``net`` to ``z`` viewed as ``(batch, z_dim)``."""
        z = z.view(-1, self.z_dim)
        return self.net(z)

    def masked_sep(self, z):
        """Same computation as :meth:`masked`."""
        return self.masked(z)

    def g(self, z, i=None):
        """Apply ``net_g`` to ``z``. ``i`` is accepted but unused."""
        return self.net_g(z)

    def mix(self, z):
        """Apply ``net{k}`` to the k-th concept chunk of ``z`` and concatenate.

        ``z`` is viewed as ``(batch, concept * z1_dim)``. With ``z1_dim == 1``
        each concept is one column. Otherwise the columns are split into
        chunks of ``z_dim // concept``.

        Raises:
            ValueError: if ``concept`` is not 3 or 4 (only four per-concept
                nets exist), or the split does not give ``concept`` chunks.
        """
        if self.concept not in (3, 4):
            raise ValueError(
                f"mix supports concept in (3, 4), got {self.concept}")
        zy = z.view(-1, self.concept * self.z1_dim)
        if self.z1_dim == 1:
            zy = zy.reshape(zy.size()[0], zy.size()[1], 1)
            chunks = [zy[:, k] for k in range(self.concept)]
        else:
            chunks = torch.split(zy, self.z_dim // self.concept, dim=1)
            if len(chunks) != self.concept:
                raise ValueError(
                    f"expected {self.concept} chunks of width "
                    f"{self.z_dim // self.concept}, got {len(chunks)}")
        nets = (self.net1, self.net2, self.net3, self.net4)
        return torch.cat([net(chunk) for net, chunk in zip(nets, chunks)], dim=1)
class Mix(nn.Module):
    """Apply a separate small ELU MLP (hidden width 16) to each of four
    latent concept chunks and concatenate the results."""

    def __init__(self, z_dim, concept, z1_dim):
        super().__init__()
        self.z_dim = z_dim
        self.z1_dim = z1_dim
        self.concept = concept
        self.elu = nn.ELU()
        # Registers net1..net4 in order.
        for idx in range(1, 5):
            setattr(self, f"net{idx}", nn.Sequential(
                nn.Linear(z1_dim, 16),
                nn.ELU(),
                nn.Linear(16, z1_dim),
            ))

    def mix(self, z):
        """Transform each of the four concept chunks of ``z`` by its own net."""
        flat = z.view(-1, self.concept * self.z1_dim)
        if self.z1_dim != 1:
            a, b, c, d = torch.split(flat, self.z_dim // self.concept, dim=1)
        else:
            flat = flat.reshape(flat.size()[0], flat.size()[1], 1)
            a, b, c, d = flat[:, 0], flat[:, 1], flat[:, 2], flat[:, 3]
        pieces = (self.net1(a), self.net2(b), self.net3(c), self.net4(d))
        return torch.cat(pieces, dim=1)
class CausalLayer(nn.Module):
    """Latent transform layer with a whole-latent MLP and four per-concept MLPs.

    Every method returns ``(transformed_z, v)``, passing ``v`` through
    unchanged.
    """

    def __init__(self, z_dim, concept=4, z1_dim=4):
        super().__init__()
        self.z_dim = z_dim
        self.z1_dim = z1_dim
        self.concept = concept
        self.elu = nn.ELU()
        branches = [
            nn.Sequential(nn.Linear(z1_dim, 32), nn.ELU(), nn.Linear(32, z1_dim))
            for _ in range(4)
        ]
        self.net1, self.net2, self.net3, self.net4 = branches
        self.net = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ELU(),
            nn.Linear(128, z_dim),
        )

    def calculate(self, z, v):
        """Apply ``net`` to ``z`` viewed as ``(batch, z_dim)``."""
        return self.net(z.view(-1, self.z_dim)), v

    def masked_sep(self, z, v):
        """Same computation as :meth:`calculate`."""
        return self.calculate(z, v)

    def calculate_dag(self, z, v):
        """Apply ``net{k}`` to the k-th of four concept chunks and concatenate."""
        flat = z.view(-1, self.concept * self.z1_dim)
        if self.z1_dim != 1:
            p1, p2, p3, p4 = torch.split(flat, self.z_dim // self.concept, dim=1)
        else:
            cols = flat.reshape(flat.size()[0], flat.size()[1], 1)
            p1, p2, p3, p4 = cols[:, 0], cols[:, 1], cols[:, 2], cols[:, 3]
        merged = torch.cat(
            (self.net1(p1), self.net2(p2), self.net3(p3), self.net4(p4)), dim=1)
        return merged, v
class Attention(nn.Module):
    """Bilinear attention with a learned ``in_features x in_features`` matrix.

    ``bias`` is accepted for interface compatibility but unused.
    """

    def __init__(self, in_features, bias=False):
        super().__init__()
        weight = torch.zeros(in_features, in_features)
        torch.nn.init.normal_(weight, mean=0, std=1)
        self.M = nn.Parameter(weight)
        self.sigmd = torch.nn.Sigmoid()

    def attention(self, z, e):
        """Attend over ``e`` using scores ``sigmoid(z @ M @ e^T)``.

        Scores are normalised with a softmax over dim 1. Returns the attended
        values ``A @ e`` and the attention matrix ``A``.
        """
        scores = self.sigmd(z.matmul(self.M).matmul(e.permute(0, 2, 1)))
        weights = torch.softmax(scores, dim=1)
        return torch.matmul(weights, e), weights
class DagLayer(nn.Linear):
    """Linear SCM layer driven by a fixed DAG adjacency matrix ``A``.

    :meth:`forward` and :meth:`calculate_dag` map exogenous noise ``x`` to
    ``x @ inverse(I - A^T)^T``, i.e. they solve ``z = A^T z + x``. Under this
    convention column ``j`` of ``A`` marks the parents of node ``j``. The
    transform acts on the last axis of a 2-D input, or on axis 1 of a 3-D
    input. ``A`` is stored as a plain tensor, not a learnable parameter.

    Args:
        in_features: stored for interface compatibility with ``nn.Linear``.
        out_features: number of nodes; size of ``A``, ``I`` and ``bias``.
        A: ``(out_features, out_features)`` adjacency matrix. Defaults to
            all zeros (no edges) when omitted.
        i: stored flag; not read by any method of this class.
        bias: when True, a zero-initialised bias is created and used by the
            ``F.linear`` based methods.
    """

    def __init__(self, in_features, out_features, A=None, i=False, bias=False):
        # Intentionally bypass nn.Linear.__init__ (no learnable weight is
        # created) and run only nn.Module.__init__.
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.i = i
        if A is None:
            # Previously A=None left self.A as None and every method failed.
            # Fall back to the empty graph instead.
            A = torch.zeros(out_features, out_features)
        self.a = A
        self.A = self.a
        self.b = torch.eye(out_features)
        self.B = self.b
        self.I = torch.eye(out_features)
        if bias:
            # ``Parameter`` is not imported in this module (NameError before).
            # Use nn.Parameter, starting from zeros rather than uninitialised
            # memory.
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def mask_z(self, x, i):
        """Scale the rows of ``x`` by column ``i`` of ``A + I``.

        ``x`` presumably has ``out_features`` rows so that the
        ``(out_features, 1)`` column broadcasts — TODO confirm against callers.
        Side effect: sets ``self.B`` to ``A`` on ``x``'s device.
        """
        self.B = self.A.to(x.device)
        col = (self.B + self.I.to(x.device))[:, i].reshape(-1, 1)
        return torch.mul(col, x)

    def mask_z_orig(self, x):
        """Return ``A^T @ x`` (``A`` cast to float). Sets ``self.B``."""
        self.B = self.A.to(x.device)
        return torch.matmul(self.B.t().float(), x)

    def mask_z_learn(self, x, i):
        """Same computation as :meth:`mask_z`."""
        self.B = self.A.to(x.device)
        col = (self.B + self.I.to(x.device))[:, i].reshape(-1, 1)
        return torch.mul(col, x)

    def mask_u(self, x):
        """Return ``A^T @ x`` with ``x`` viewed as ``(batch, n, 1)``. Sets ``self.B``."""
        self.B = self.A.to(x.device)
        x = x.view(-1, x.size()[1], 1)
        return torch.matmul(self.B.t(), x)

    def inv_cal(self, x, v):
        """Apply ``F.linear(x, I - A, bias)``; returns ``(x, v)``."""
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
        # Build the weight on x's device; self.I always lives on the CPU.
        weight = self.I.to(x.device) - self.A.to(x.device)
        x = F.linear(x, weight, self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
        return x, v

    def calculate_dag(self, x, v):
        """Solve ``z = A^T z + x`` (same mapping as :meth:`forward`).

        Returns ``(z, v)``. Side effect: moves ``self.A`` to ``x``'s device.
        """
        self.A = self.A.to(x.device)
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
        x = F.linear(x, torch.inverse(self.I.to(x.device) - self.A.t()), self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
        return x, v

    def calculate_cov(self, x, v):
        """Propagate mean and covariance through ``inverse(I - A)``.

        NOTE(review): ``dag_left_linear`` and ``dag_right_linear`` are not
        defined in the visible part of this module. Calling this raises
        NameError unless they are provided elsewhere.
        """
        v = ut.vector_expand(v)
        x = dag_left_linear(x, torch.inverse(self.I - self.A), self.bias)
        v = dag_left_linear(v, torch.inverse(self.I - self.A), self.bias)
        v = dag_right_linear(v, torch.inverse(self.I - self.A), self.bias)
        return x, v

    def calculate_gaussian_ini(self, x, v):
        """Map mean ``x`` through ``M = inverse(I - A)`` and variance ``v``
        through the element-wise square ``M * M``. Returns ``(x, v)``."""
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
            v = v.permute(0, 2, 1)
        mix = torch.inverse(self.I.to(x.device) - self.A.to(x.device))
        x = F.linear(x, mix, self.bias)
        v = F.linear(v, torch.mul(mix, mix), self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
            v = v.permute(0, 2, 1).contiguous()
        return x, v

    def forward(self, x):
        """Return ``x @ inverse(I - A^T)^T`` (solves ``z = A^T z + x``)."""
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
        x = torch.matmul(x, torch.inverse(self.I.to(x.device) - self.A.t().to(x.device)).t())
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
        return x

    def calculate_gaussian(self, x, v):
        """Gaussian propagation via the ``dag_*_linear`` helpers.

        NOTE(review): ``dag_left_linear`` and ``dag_right_linear`` are not
        defined in the visible part of this module. Calling this raises
        NameError unless they are provided elsewhere.
        """
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
            v = v.permute(0, 2, 1)
        x = dag_left_linear(x, torch.inverse(self.I - self.A), self.bias)
        v = dag_left_linear(v, torch.inverse(self.I - self.A), self.bias)
        v = dag_right_linear(v, torch.inverse(self.I - self.A), self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
            v = v.permute(0, 2, 1).contiguous()
        return x, v
class DagLayerOrig(nn.Linear):
    """Linear SCM layer with a hard-coded adjacency matrix.

    ``A`` starts as zeros with ``A[0, 2:4]`` and ``A[1, 2:4]`` set to 1.
    :meth:`forward` maps ``x`` to ``x @ inverse(I - A^T)^T`` (solving
    ``z = A^T z + x``), so nodes 2 and 3 each receive nodes 0 and 1 as
    parents. ``out_features`` must be at least 4. ``A`` is a plain tensor,
    not a learnable parameter.
    """

    def __init__(self, in_features, out_features, i=False, bias=False):
        # Intentionally bypass nn.Linear.__init__ (no learnable weight is
        # created) and run only nn.Module.__init__.
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.i = i
        self.a = torch.zeros(out_features, out_features)
        self.a[0, 2:4], self.a[1, 2:4] = 1, 1
        self.A = self.a
        self.b = torch.eye(out_features)
        self.B = self.b
        self.I = torch.eye(out_features)
        if bias:
            # ``Parameter`` is not imported in this module (NameError before).
            # Use nn.Parameter, starting from zeros rather than uninitialised
            # memory.
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def mask_z(self, x):
        """Return ``A^T @ x``. Sets ``self.B`` to ``A`` on ``x``'s device."""
        self.B = self.A.to(x.device)
        return torch.matmul(self.B.t(), x)

    def mask_u(self, x):
        """Return ``A^T @ x`` with ``x`` viewed as ``(batch, n, 1)``. Sets ``self.B``."""
        self.B = self.A.to(x.device)
        x = x.view(-1, x.size()[1], 1)
        return torch.matmul(self.B.t(), x)

    def inv_cal(self, x, v):
        """Apply ``F.linear(x, I - A, bias)``; returns ``(x, v)``."""
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
        # Build the weight on x's device; self.I / self.A live on the CPU.
        weight = self.I.to(x.device) - self.A.to(x.device)
        x = F.linear(x, weight, self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
        return x, v

    def calculate_dag(self, x):
        """Solve ``z = A^T z + x`` (same mapping as :meth:`forward`)."""
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
        x = F.linear(x, torch.inverse(self.I.to(x.device) - self.A.t().to(x.device)), self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
        return x

    def calculate_cov(self, x, v):
        """Propagate mean and covariance through ``inverse(I - A)``.

        NOTE(review): ``dag_left_linear`` and ``dag_right_linear`` are not
        defined in the visible part of this module. Calling this raises
        NameError unless they are provided elsewhere.
        """
        v = ut.vector_expand(v)
        x = dag_left_linear(x, torch.inverse(self.I - self.A), self.bias)
        v = dag_left_linear(v, torch.inverse(self.I - self.A), self.bias)
        v = dag_right_linear(v, torch.inverse(self.I - self.A), self.bias)
        return x, v

    def calculate_gaussian_ini(self, x, v):
        """Map mean ``x`` through ``M = inverse(I - A)`` and variance ``v``
        through the element-wise square ``M * M``. Returns ``(x, v)``."""
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
            v = v.permute(0, 2, 1)
        mix = torch.inverse(self.I.to(x.device) - self.A.to(x.device))
        x = F.linear(x, mix, self.bias)
        v = F.linear(v, torch.mul(mix, mix), self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
            v = v.permute(0, 2, 1).contiguous()
        return x, v

    def forward(self, x):
        """Return ``x @ inverse(I - A^T)^T`` (solves ``z = A^T z + x``)."""
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
        x = torch.matmul(x, torch.inverse(self.I.to(x.device) - self.A.t().to(x.device)).t())
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
        return x

    def calculate_gaussian(self, x, v):
        """Gaussian propagation via the ``dag_*_linear`` helpers.

        NOTE(review): ``dag_left_linear`` and ``dag_right_linear`` are not
        defined in the visible part of this module. Calling this raises
        NameError unless they are provided elsewhere.
        """
        if x.dim() > 2:
            x = x.permute(0, 2, 1)
            v = v.permute(0, 2, 1)
        x = dag_left_linear(x, torch.inverse(self.I - self.A), self.bias)
        v = dag_left_linear(v, torch.inverse(self.I - self.A), self.bias)
        v = dag_right_linear(v, torch.inverse(self.I - self.A), self.bias)
        if x.dim() > 2:
            x = x.permute(0, 2, 1).contiguous()
            v = v.permute(0, 2, 1).contiguous()
        return x, v
# Conv2d output size: out = floor((in - k + 2p) / s) + 1
class ConvEncoder(nn.Module):
    """Convolutional Gaussian encoder.

    :meth:`encode` reshapes the mean/variance heads to ``8 * 8`` features,
    which corresponds to a 3-channel 128x128 input
    (128 -> 64 -> 32 -> 16 -> 8). :meth:`encode_simple` runs the deeper
    ``conv6`` stack and reshapes its output to 256 features (2x2x64 for a
    128x128 input).
    """

    def __init__(self, out_dim=None):
        super().__init__()
        # Each k=4, s=2, p=1 conv halves the spatial size.
        self.conv1 = torch.nn.Conv2d(3, 32, 4, 2, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 4, 2, 1, bias=False)
        self.conv3 = torch.nn.Conv2d(64, 1, 4, 2, 1, bias=False)
        self.LReLU = torch.nn.LeakyReLU(0.2, inplace=True)
        # Separate single-channel heads for the mean and the variance.
        self.convm = torch.nn.Conv2d(1, 1, 4, 2, 1)
        self.convv = torch.nn.Conv2d(1, 1, 4, 2, 1)
        self.mean_layer = nn.Sequential(torch.nn.Linear(8 * 8, out_dim))
        self.var_layer = nn.Sequential(torch.nn.Linear(8 * 8, out_dim))
        # Deeper stack used by encode_simple: five conv+ReLU stages, then a
        # final conv down to 64 channels.
        widths = [3, 32, 64, 64, 64, 256]
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.ReLU(True)]
        stages.append(nn.Conv2d(256, 64, 4, 2, 1))
        self.conv6 = nn.Sequential(*stages)

    def encode(self, x):
        """Return ``(mu, var)``; ``var`` is ``softplus(...) + 1e-8``."""
        feats = x
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = self.LReLU(conv(feats))
        mean_feats = self.convm(feats).view(-1, 8 * 8)
        var_feats = self.convv(feats).view(-1, 8 * 8)
        mu = self.mean_layer(mean_feats)
        var = F.softplus(self.var_layer(var_feats)) + 1e-8
        return mu, var

    def encode_simple(self, x):
        """Run ``conv6`` and split its 256 features with ``ut.gaussian_parameters``."""
        feats = self.conv6(x).reshape(x.shape[0], 256)
        return ut.gaussian_parameters(feats, dim=1)
# ConvTranspose2d output size: out = (in - 1) * s - 2p + k
class ConvDecoder(nn.Module):
    """Transposed-conv decoder from a ``z2_dim`` vector to a 3x128x128 image.

    Spatial sizes: 1x1x128 -> 4x4x64 -> 8x8x64 -> 16x16x32 -> 32x32x32 ->
    64x64x32 -> 128x128x3.
    """

    def __init__(self, z2_dim):
        super().__init__()
        self.z2_dim = z2_dim
        layers = [
            nn.Conv2d(self.z2_dim, 128, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in ((64, 64), (64, 32), (32, 32), (32, 32)):
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2)]
        layers.append(nn.ConvTranspose2d(32, 3, 4, 2, 1))
        self.net6 = nn.Sequential(*layers)

    def decode_sep(self, x):
        """Placeholder; always returns ``None``."""
        return None

    def decode(self, z):
        """Decode ``z`` (viewable as ``(batch, z2_dim)``) into images."""
        return self.net6(z.view(-1, self.z2_dim, 1, 1))
class ConvDec(nn.Module):
    """Per-concept convolutional decoders plus a joint decoder.

    :meth:`decode_sep` decodes each of four concept chunks with its own
    :class:`ConvDecoder` and averages the images. :meth:`decode` decodes the
    full latent (``concept * z1_dim`` channels; ``net6`` is built for 16)
    with ``net6``. ``net5`` is constructed but not used here.
    """

    def __init__(self, concept, z1_dim, z_dim):
        super().__init__()
        self.concept = concept
        self.z1_dim = z1_dim
        self.z_dim = z_dim
        self.net1 = ConvDecoder(z1_dim)
        self.net2 = ConvDecoder(z1_dim)
        self.net3 = ConvDecoder(z1_dim)
        self.net4 = ConvDecoder(z1_dim)
        self.net5 = nn.Sequential(
            nn.Linear(16, 512),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
        )
        layers = [
            nn.Conv2d(16, 128, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in ((64, 64), (64, 32), (32, 32), (32, 32)):
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2)]
        layers.append(nn.ConvTranspose2d(32, 3, 4, 2, 1))
        self.net6 = nn.Sequential(*layers)

    def decode_sep(self, z, u=None, y=None):
        """Decode four concept chunks separately and average them.

        Returns the averaged image five times.
        """
        flat = z.view(-1, self.concept * self.z1_dim)
        if y is not None:
            flat = torch.cat((flat, y), dim=1)
        c1, c2, c3, c4 = torch.split(flat, self.z_dim // self.concept, dim=1)
        total = (self.net1.decode(c1) + self.net2.decode(c2)
                 + self.net3.decode(c3) + self.net4.decode(c4))
        avg = total / 4
        return avg, avg, avg, avg, avg

    def decode(self, z, u=None, y=None):
        """Decode the whole latent with ``net6``."""
        return self.net6(z.view(-1, self.concept * self.z1_dim, 1, 1))
#############################################################################################################
############################################# CELEBA ENC/DEC ################################################
#############################################################################################################
class CelebAConvEncoder(nn.Module):
    """Convolutional Gaussian encoder for 128x128 images.

    ``conv`` reduces the input to a 1x1x512 feature map
    (64 -> 32 -> 16 -> 8 -> 4 -> 1). :meth:`encode` flattens it and applies
    ``fc_mu`` / ``fc_var``. ``self.encoder`` (a BatchNorm conv stack) is
    constructed but not used by :meth:`encode`. ``out_dim`` is unused.
    """

    def __init__(self, latent_dim, in_channels=3, out_dim=None):
        super().__init__()
        self.latent_dim = latent_dim
        stride2 = [(in_channels, 32), (32, 64), (64, 64), (64, 128), (128, 128)]
        layers = []
        for c_in, c_out in stride2:
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        layers += [
            nn.Conv2d(128, 256, 4, 1),  # 4x4 -> 1x1
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 1),  # 1x1x512
        ]
        self.conv = nn.Sequential(*layers)

        hidden_dims = [32, 64, 128, 256, 512, 512, 512]
        blocks = []
        prev = in_channels
        for h_dim in hidden_dims:
            blocks.append(nn.Sequential(
                nn.Conv2d(prev, out_channels=h_dim, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            prev = h_dim
        self.encoder = nn.Sequential(*blocks)
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

    def encode(self, x):
        """Return ``(mu, var)``; ``var`` is ``softplus(...) + 1e-8``."""
        feats = self.conv(x).view(-1, 512)
        mu = self.fc_mu(feats)
        var = F.softplus(self.fc_var(feats)) + 1e-8
        return mu, var
class CelebAConvDecoder(nn.Module):
    """Transposed-conv decoder from a ``latent_dim`` vector to a 128x128 image.

    :meth:`decode` uses ``convT``: 1x1x512 -> 4x4x256 -> 8x8x128 ->
    16x16x128 -> 32x32x64 -> 64x64x64 -> 128x128x32 -> 128x128x
    ``out_channels``. ``decoder_input``, ``decoder`` and ``final_layer``
    (an alternative BatchNorm/Tanh pipeline) are constructed but not used by
    :meth:`decode`. ``out_dim`` is unused.
    """

    def __init__(self, latent_dim, out_channels=3, out_dim=None):
        super().__init__()
        self.latent_dim = latent_dim
        layers = [
            nn.Conv2d(latent_dim, 512, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(512, 256, 4),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in ((256, 128), (128, 128), (128, 64), (64, 64), (64, 32)):
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        layers.append(nn.ConvTranspose2d(32, out_channels, 1))
        self.convT = nn.Sequential(*layers)

        hidden_dims = [32, 64, 128, 256, 512, 512, 512]
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])
        rev = list(reversed(hidden_dims))
        blocks = []
        for c_in, c_out in zip(rev, rev[1:]):
            blocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        self.decoder = nn.Sequential(*blocks)
        last = rev[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(last, last, kernel_size=4, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(last),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(last, out_channels=out_channels, kernel_size=4, padding=1),
            nn.Tanh(),
        )

    def decode(self, z):
        """Decode ``z`` (viewable as ``(batch, latent_dim)``) with ``convT``."""
        return self.convT(z.view(-1, self.latent_dim, 1, 1))
class CelebAConvDec(nn.Module):
    """Four per-concept :class:`CelebAConvDecoder` s whose images are averaged.

    The latent of size ``latent_dim`` is split into four equal concept chunks
    of ``latent_dim // 4`` values. ``out_dim`` is unused.
    """

    def __init__(self, latent_dim, out_dim=None):
        super().__init__()
        self.concept = 4
        self.z_dim = latent_dim
        self.z1_dim = self.z_dim // self.concept
        self.net1 = CelebAConvDecoder(self.z1_dim)
        self.net2 = CelebAConvDecoder(self.z1_dim)
        self.net3 = CelebAConvDecoder(self.z1_dim)
        self.net4 = CelebAConvDecoder(self.z1_dim)

    def decode_sep(self, z, u=None, y=None):
        """Decode each concept chunk separately; return their mean five times."""
        flat = z.view(-1, self.concept * self.z1_dim)
        if y is not None:
            flat = torch.cat((flat, y), dim=1)
        c1, c2, c3, c4 = torch.split(flat, self.z_dim // self.concept, dim=1)
        total = (self.net1.decode(c1) + self.net2.decode(c2)
                 + self.net3.decode(c3) + self.net4.decode(c4))
        avg = total / 4
        return avg, avg, avg, avg, avg
class ConvEncoderPend(nn.Module):
    """Convolutional Gaussian encoder for 96x96 images.

    :meth:`encode`: 96 -> 48 -> 24 -> 12 (1 channel) -> 10 -> 8x8 heads,
    giving ``latent_dim`` means and variances. :meth:`encode_simple` runs the
    ``conv`` stack down to 1x1x(2 * latent_dim) and splits it with
    ``ut.gaussian_parameters``. ``out_dim`` is unused.
    """

    def __init__(self, latent_dim, in_channel=3, out_dim=None):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channel, 24, 4, 2, 1)  # 48x48
        self.conv2 = torch.nn.Conv2d(24, 48, 4, 2, 1, bias=False)  # 24x24
        self.conv3 = torch.nn.Conv2d(48, 1, 4, 2, 1, bias=False)  # 12x12
        self.conv4 = torch.nn.Conv2d(1, 1, 3, 1, bias=False)  # 10x10
        self.LReLU = torch.nn.LeakyReLU(0.2, inplace=True)
        # k=3, s=1, no padding: 10x10 -> 8x8 for both heads.
        self.convm = torch.nn.Conv2d(1, 1, 3, 1)
        self.convv = torch.nn.Conv2d(1, 1, 3, 1)
        self.mean_layer = nn.Sequential(torch.nn.Linear(8 * 8, latent_dim))
        self.var_layer = nn.Sequential(torch.nn.Linear(8 * 8, latent_dim))

        stride2 = [(in_channel, 24), (24, 24), (24, 48), (48, 48), (48, 48)]
        layers = []
        for c_in, c_out in stride2:  # 48 -> 24 -> 12 -> 6 -> 3
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2)]
        layers += [
            nn.Conv2d(48, 96, 3, 1),  # 3x3 -> 1x1x96
            nn.LeakyReLU(0.2),
            nn.Conv2d(96, latent_dim * 2, 4, 2, 2),  # 1x1 x (2 * latent_dim)
        ]
        self.conv = nn.Sequential(*layers)

    def encode(self, x):
        """Return ``(mu, var)``; ``var`` is ``softplus(...) + 1e-8``."""
        feats = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats = self.LReLU(conv(feats))
        mean_feats = self.convm(feats).view(-1, 8 * 8)
        var_feats = self.convv(feats).view(-1, 8 * 8)
        mu = self.mean_layer(mean_feats)
        var = F.softplus(self.var_layer(var_feats)) + 1e-8
        return mu, var

    def encode_simple(self, x):
        """Run ``conv`` and split its channels with ``ut.gaussian_parameters``."""
        return ut.gaussian_parameters(self.conv(x), dim=1)
# CONVOLUTIONAL DECODER LAYER
class ConvDecoderPend(nn.Module):
    """Transposed-conv decoder from a latent vector to a 96x96 image.

    Spatial sizes: 1x1x96 -> 4x4x48 -> 6x6x48 -> 12x12x24 -> 24x24x24 ->
    48x48x24 -> 96x96x``channels``. ``out_dim`` is unused.
    """

    def __init__(self, latent_dim, channels=3, out_dim=None):
        super().__init__()
        layers = [
            nn.Conv2d(latent_dim, 96, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(96, 48, 4),  # 1x1 -> 4x4
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(48, 48, 2, 2, 1),  # 4x4 -> 6x6
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in ((48, 24), (24, 24), (24, 24)):  # 12 -> 24 -> 48
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2)]
        layers.append(nn.ConvTranspose2d(24, channels, 4, 2, 1))  # 96x96
        self.net6 = nn.Sequential(*layers)

    def decode_sep(self, x):
        """Placeholder; always returns ``None``."""
        return None

    def decode(self, z):
        """Decode ``z`` of shape ``(batch, latent_dim)`` into images."""
        return self.net6(z.view(-1, z.shape[1], 1, 1))
class ConvDecPend(nn.Module):
    """Four per-concept :class:`ConvDecoderPend` s plus a joint decoder.

    :meth:`decode_sep` splits a ``latent_dim`` latent into four concept
    chunks, decodes each separately and averages the images. :meth:`decode`
    decodes the whole latent with ``net6``, whose first conv expects 16
    channels, so it requires ``latent_dim == 16``. ``net5`` is constructed
    but not used here. ``out_dim`` is unused.
    """

    def __init__(self, latent_dim, channels=3, out_dim=None):
        super().__init__()
        self.concept = 4
        self.z_dim = latent_dim
        self.z1_dim = self.z_dim // self.concept
        self.net1 = ConvDecoderPend(self.z1_dim, channels)
        self.net2 = ConvDecoderPend(self.z1_dim, channels)
        self.net3 = ConvDecoderPend(self.z1_dim, channels)
        self.net4 = ConvDecoderPend(self.z1_dim, channels)
        self.net5 = nn.Sequential(
            nn.Linear(16, 512),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
        )
        # 1x1x128 -> 4x4x64 -> 8 -> 16 -> 32 -> 64 -> 128x128x3
        self.net6 = nn.Sequential(
            nn.Conv2d(16, 128, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
        )

    def decode_sep(self, z, u=None, y=None):
        """Decode each concept chunk separately; return their mean five times.

        Args:
            z: latent viewable as ``(batch, concept * z1_dim)``.
            u: unused.
            y: optional tensor appended to ``z`` along dim 1 before splitting.
        """
        z = z.view(-1, self.concept * self.z1_dim)
        zy = z if y is None else torch.cat((z, y), dim=1)
        zy1, zy2, zy3, zy4 = torch.split(zy, self.z_dim // self.concept, dim=1)
        rx1 = self.net1.decode(zy1)
        rx2 = self.net2.decode(zy2)
        rx3 = self.net3.decode(zy3)
        rx4 = self.net4.decode(zy4)
        z = (rx1 + rx2 + rx3 + rx4) / 4
        return z, z, z, z, z

    def decode(self, z, u=None, y=None):
        """Decode the whole latent with ``net6``; ``u`` and ``y`` are unused."""
        z = z.view(-1, self.concept * self.z1_dim, 1, 1)
        # A leftover debug print of the output size on every call was removed.
        return self.net6(z)