import torch
import torch.nn as nn
from torch.distributions import Normal


class MLP(nn.Module):
    """A simple three-layer MLP used as the conditioner for scale and shift."""

    def __init__(self, nin, nout, nh):
        super().__init__()
        # Note: the final Sigmoid bounds the conditioner's outputs to (0, 1).
        self.net = nn.Sequential(
            nn.Linear(nin, nh),
            nn.ReLU(),
            nn.Linear(nh, nh),
            nn.ReLU(),
            nn.Linear(nh, nout),
            nn.Sigmoid(),
        )

    def forward(self, x, mask):
        # Zero out the non-parent entries of x before conditioning.
        return self.net(x * mask)


class MultivariateCausalFlow(nn.Module):
    """Affine autoregressive flow over `dim` nodes with `k` features per node,
    masked by the causal adjacency matrix C (C[j, i] = 1 iff j is a parent of i)."""

    def __init__(self, dim, k, C=None, net_class=MLP, nh=100, scale=True, shift=True):
        super().__init__()
        self.dim = dim
        self.k = k

        self.C = C
        self.A = torch.eye(self.C.shape[0]) - self.C

        if scale:
            self.s_cond = net_class(self.dim * self.k, self.k, nh)
        if shift:
            self.t_cond = net_class(self.dim * self.k, self.k, nh)

        self.z_int_prior = Normal(0.0, 1.0)

    def forward(self, e, target=None, value=None):
        total_dims = e.shape[1] * e.shape[2]
        log_det = torch.zeros(e.size(0)).to(e.device)
        batch_size = e.shape[0]
        z = torch.zeros(batch_size, self.dim, self.k).to(e.device)

        # Nodes are assumed to be indexed in topological order of C.
        for i in range(self.dim):
            if target == i or 1 not in self.C[:, i]:
                # Root node, or intervened node whose incoming edges are cut:
                # condition on nothing.
                mask = torch.zeros(total_dims).to(e.device)
            else:
                # Condition only on the parents of node i.
                mask = self.C[:, i].repeat(self.k, 1).T.reshape(total_dims).to(e.device)

            s = self.s_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k)
            t = self.t_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k)

            z[:, i, :] = torch.exp(s) * e[:, i, :].reshape(batch_size, self.k) + t

            if target is not None and value is not None:
                # Re-clamp the intervened node every step so that downstream
                # nodes are generated from the intervened value.
                z[:, target, :] = value

            log_det += torch.sum(s, dim=1)

        return z, log_det

    def backward(self, z, target=None, value=None):
        total_dims = z.shape[1] * z.shape[2]
        log_det = torch.zeros(z.size(0)).to(z.device)
        batch_size = z.shape[0]
        e = torch.zeros(batch_size, self.dim, self.k).to(z.device)

        for i in range(self.dim):
            if target == i or 1 not in self.C[:, i]:
                mask = torch.zeros(total_dims).to(z.device)
            else:
                mask = self.C[:, i].repeat(self.k, 1).T.reshape(total_dims).to(z.device)

            s = self.s_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k)
            t = self.t_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k)

            # Invert the affine transform of the forward pass.
            e[:, i, :] = torch.exp(-s) * (z[:, i, :].reshape(batch_size, self.k) - t)

            # Mirror the forward pass: subtract the full sum of s, not its mean.
            log_det -= torch.sum(s, dim=1)

        return e, log_det

    def forward_interv(self, e, I):
        """Forward pass under per-sample interventions. I is a (batch, dim, k)
        binary mask; I[:, i] == 1 marks entries of node i that are intervened on."""
        total_dims = e.shape[1] * e.shape[2]
        batch_size = e.shape[0]
        z = torch.zeros(batch_size, self.dim, self.k).to(e.device)

        for i in range(self.dim):
            interv_mask = (I[:, i] == 1).to(e.device)

            if 1 in self.C[:, i]:
                mask = self.C[:, i].repeat(self.k, 1).T.reshape(total_dims).to(e.device)
            else:
                mask = torch.zeros(total_dims).to(e.device)

            # Intervened entries are conditioned on nothing (zero mask);
            # observational entries are conditioned on their parents.
            s = torch.where(interv_mask.reshape(batch_size, self.k),
                            self.s_cond(z.reshape(-1, total_dims), torch.zeros(total_dims).to(e.device)).reshape(batch_size, self.k),
                            self.s_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k))

            t = torch.where(interv_mask.reshape(batch_size, self.k),
                            self.t_cond(z.reshape(-1, total_dims), torch.zeros(total_dims).to(e.device)).reshape(batch_size, self.k),
                            self.t_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k))

            # Same affine transform as in forward().
            z[:, i, :] = torch.exp(s) * e[:, i, :] + t

        return z
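

def _demo_multivariate_flow():
    # A minimal, hedged usage sketch; this helper and its numbers are
    # illustrative assumptions, not part of the original module. It builds a
    # 3-node chain x0 -> x1 -> x2 with k = 2 features per node, where
    # C[j, i] = 1 means j is a parent of i.
    C = torch.tensor([[0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0],
                      [0.0, 0.0, 0.0]])
    flow = MultivariateCausalFlow(dim=3, k=2, C=C)
    e = torch.randn(8, 3, 2)
    z, log_det = flow(e)                      # observational forward pass
    z_do, _ = flow(e, target=1, value=0.0)    # forward pass under do(x1 = 0)
    e_rec, _ = flow.backward(z)               # inverse of the forward pass
    return z, z_do, e_rec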


class PriorMultivariateCausalFlow(nn.Module):
    """Variant of MultivariateCausalFlow whose conditioners read an external
    latent code rather than the partially built sample z."""

    def __init__(self, dim, k, C=None, net_class=MLP, nh=100, scale=True, shift=True):
        super().__init__()
        self.dim = dim
        self.k = k

        self.C = C
        self.A = torch.eye(self.C.shape[0]) - self.C

        if scale:
            self.s_cond = net_class(self.dim * self.k, self.k, nh)
        if shift:
            self.t_cond = net_class(self.dim * self.k, self.k, nh)

        self.z_int_prior = Normal(0.0, 1.0)

    def forward(self, e, latent=None, target=None, value=None):
        total_dims = e.shape[1] * e.shape[2]
        log_det = torch.zeros(e.size(0)).to(e.device)
        batch_size = e.shape[0]
        z = torch.zeros(batch_size, self.dim, self.k).to(e.device)

        for i in range(self.dim):
            if target == i or 1 not in self.C[:, i]:
                mask = torch.zeros(total_dims).to(e.device)
            else:
                mask = self.C[:, i].repeat(self.k, 1).T.reshape(total_dims).to(e.device)

            # Condition on the external latent code, not on z itself.
            s = self.s_cond(latent.reshape(-1, total_dims), mask).reshape(batch_size, self.k)
            t = self.t_cond(latent.reshape(-1, total_dims), mask).reshape(batch_size, self.k)

            z[:, i, :] = torch.exp(s) * e[:, i, :].reshape(batch_size, self.k) + t

            if target is not None and value is not None:
                # Clamp the intervened node so the returned sample reflects
                # the intervention.
                z[:, target, :] = value

            log_det += torch.sum(s, dim=1)

        return z, log_det

    def backward(self, z, latent=None, target=None, value=None):
        total_dims = z.shape[1] * z.shape[2]
        log_det = torch.zeros(z.size(0)).to(z.device)
        batch_size = z.shape[0]
        e = torch.zeros(batch_size, self.dim, self.k).to(z.device)

        # To invert forward(), the conditioners must see the same inputs, so
        # condition on the external latent when one is provided and fall back
        # to z otherwise.
        cond = latent if latent is not None else z

        for i in range(self.dim):
            if target == i or 1 not in self.C[:, i]:
                mask = torch.zeros(total_dims).to(z.device)
            else:
                mask = self.C[:, i].repeat(self.k, 1).T.reshape(total_dims).to(z.device)

            s = self.s_cond(cond.reshape(-1, total_dims), mask).reshape(batch_size, self.k)
            t = self.t_cond(cond.reshape(-1, total_dims), mask).reshape(batch_size, self.k)

            e[:, i, :] = torch.exp(-s) * (z[:, i, :].reshape(batch_size, self.k) - t)

            # Mirror the forward pass: subtract the full sum of s, not its mean.
            log_det -= torch.sum(s, dim=1)

        return e, log_det

    def forward_interv(self, e, I):
        """Forward pass under per-sample interventions; see
        MultivariateCausalFlow.forward_interv for the meaning of I."""
        total_dims = e.shape[1] * e.shape[2]
        batch_size = e.shape[0]
        z = torch.zeros(batch_size, self.dim, self.k).to(e.device)

        for i in range(self.dim):
            interv_mask = (I[:, i] == 1).to(e.device)

            if 1 in self.C[:, i]:
                mask = self.C[:, i].repeat(self.k, 1).T.reshape(total_dims).to(e.device)
            else:
                mask = torch.zeros(total_dims).to(e.device)

            # Intervened entries are conditioned on nothing (zero mask);
            # observational entries are conditioned on their parents.
            s = torch.where(interv_mask.reshape(batch_size, self.k),
                            self.s_cond(z.reshape(-1, total_dims), torch.zeros(total_dims).to(e.device)).reshape(batch_size, self.k),
                            self.s_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k))

            t = torch.where(interv_mask.reshape(batch_size, self.k),
                            self.t_cond(z.reshape(-1, total_dims), torch.zeros(total_dims).to(e.device)).reshape(batch_size, self.k),
                            self.t_cond(z.reshape(-1, total_dims), mask).reshape(batch_size, self.k))

            # Same affine transform as in forward().
            z[:, i, :] = torch.exp(s) * e[:, i, :] + t

        return z
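

def _demo_prior_flow():
    # Hedged usage sketch (an illustrative assumption, not part of the
    # original module): the same 3-node chain, but the conditioners read an
    # external latent code of the same (batch, dim, k) shape.
    C = torch.tensor([[0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0],
                      [0.0, 0.0, 0.0]])
    flow = PriorMultivariateCausalFlow(dim=3, k=2, C=C)
    e = torch.randn(8, 3, 2)
    latent = torch.randn(8, 3, 2)
    z, log_det = flow(e, latent=latent)
    e_rec, _ = flow.backward(z, latent=latent)  # inverse under the same latent
    return z, e_rec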


class CausalAffineAutoregFlow(nn.Module):
    """Scalar-per-node affine autoregressive flow masked by the causal
    adjacency matrix C."""

    def __init__(self, dim, C, net_class=MLP, nh=100, scale=True, shift=True):
        super().__init__()
        self.dim = dim

        self.C = C
        if scale:
            self.s_cond = net_class(self.dim, 1, nh)
        if shift:
            self.t_cond = net_class(self.dim, 1, nh)

        self.z_int_prior = Normal(0.0, 1.0)

    def forward(self, e):
        device = e.device
        log_det = torch.zeros(e.size(0)).to(device)
        z = torch.zeros(e.shape).to(device)

        for i in range(self.dim):
            if 1 in self.C[:, i]:
                # Condition only on the parents of node i.
                mask = self.C[:, i].reshape(self.dim).to(device)
            else:
                mask = torch.zeros(self.dim).to(device)

            s = self.s_cond(z, mask).reshape(z.shape[0])
            t = self.t_cond(z, mask).reshape(z.shape[0])

            z[:, i] = torch.exp(s) * e[:, i] + t

            log_det += s

        return z, log_det

    def backward(self, z, I, z_base_inf=None):
        """Inverse pass under per-sample interventions. I is a (batch, dim)
        binary mask; intervened nodes are resampled from the standard-normal
        interventional prior (or set to z_base_inf if given)."""
        device = z.device
        log_det = torch.zeros(z.size(0)).to(device)
        p_logprob = torch.zeros(z.size(0)).to(device)
        e = torch.zeros(z.shape).to(device)

        for i in range(self.dim):
            interv_mask = (I[:, i] == 1).to(device)

            if 1 in self.C[:, i]:
                mask = self.C[:, i].reshape(self.dim).to(device)
            else:
                mask = torch.zeros(self.dim).to(device)

            # Sample (or reuse) the base value for intervened nodes.
            z_base = torch.randn(e[:, i].shape).to(device)
            if z_base_inf is not None:
                z_base = z_base_inf

            # Replace the intervened entries of z with the base sample.
            z[:, i] = torch.where(interv_mask, z_base.clone(), z[:, i].clone())

            # Intervened entries use an unconditional (zero-mask) conditioner;
            # observational entries are conditioned on their parents.
            s = torch.where(interv_mask,
                            self.s_cond(z, torch.zeros(self.dim).to(device)).reshape(z.shape[0]),
                            self.s_cond(z, mask).reshape(z.shape[0]))

            t = torch.where(interv_mask,
                            self.t_cond(z, torch.zeros(self.dim).to(device)).reshape(z.shape[0]),
                            self.t_cond(z, mask).reshape(z.shape[0]))

            e[:, i] = torch.exp(-s) * (z[:, i] - t)

            # Only intervened entries contribute to the log-det correction and
            # to the interventional-prior log-probability.
            s_new = torch.where(interv_mask, s, torch.zeros(s.shape).to(device))
            z_val = torch.where(interv_mask,
                                self.z_int_prior.log_prob(z_base).to(device),
                                torch.zeros(z[:, i].shape).to(device))

            log_det -= s_new
            p_logprob += z_val

        return e, p_logprob, log_det
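

if __name__ == "__main__":
    # Hedged smoke test (illustrative assumption, not part of the original
    # module): a scalar 3-node chain, one forward pass, and one inverse pass
    # under a random per-sample intervention mask.
    C = torch.tensor([[0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0],
                      [0.0, 0.0, 0.0]])
    flow = CausalAffineAutoregFlow(dim=3, C=C)
    e = torch.randn(8, 3)
    z, log_det = flow(e)
    I = torch.bernoulli(torch.full((8, 3), 0.3))
    e_rec, p_logprob, log_det_b = flow.backward(z, I)
    print(z.shape, e_rec.shape, log_det.shape)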