# cs2229/codebase/utils.py
import numpy as np
import os
import shutil
import torch
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils import data
from torch.distributions.multivariate_normal import MultivariateNormal
from PIL import Image
import math
# Module-level device shared by the helpers below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bce = torch.nn.BCEWithLogitsLoss(reduction='none')
bce3 = torch.nn.BCELoss(reduction='none')
def mask_threshold(x):
    # Round to the nearest integer (yields a {0, 1} mask for inputs in [0, 1]).
    x = (x + 0.5).int().float()
    return x
def label_cov(labels):
    # Empirical covariance of a (num_samples, num_labels) label array.
    cov = torch.from_numpy(np.cov(labels, rowvar=False)).to(device)
    return cov
def get_labelcov_prior(batchsize, cov):
    # Tile a single covariance matrix into a (batchsize, d, d) prior with zero means.
    v = torch.zeros(batchsize, cov.size()[0], cov.size()[1])
    for i in range(batchsize):
        v[i] = cov
    mean = torch.zeros(batchsize, cov.size()[1])
    return mean, v
def vector_expand(v):
    # Expand (batch, d) values into (batch, d, d) diagonal matrices
    # (equivalent to torch.diag_embed(v)).
    V = torch.zeros(v.size()[0], v.size()[1], v.size()[1]).to(device)
    for i in range(v.size()[0]):
        for j in range(v.size()[1]):
            V[i, j, j] = v[i, j]
    return V
def block_matmul(a, b):
    # Unimplemented placeholder for a block-wise matrix multiply.
    return None
def multivariate_sample(m, cov):
    # One draw per batch element from N(m[i], cov[i]); the latent size is
    # hardcoded to 4 here.
    m = m.reshape(m.size()[0], 4)
    z = torch.zeros(m.size())
    for i in range(z.size()[0]):
        z[i] = MultivariateNormal(m[i].cpu(), cov[i].cpu()).sample()
    return z.to(device)
def kl_multinormal_cov(qm, qv, pm, pv):
    # Closed-form KL(q || p) between full-covariance Gaussians, per batch element:
    # 0.5 * (log|pv| - log|qv| - k + tr(pv^-1 qv) + (pm - qm)^T pv^-1 (pm - qm))
    k = qm.size(1)
    KL = torch.zeros(qm.size(0)).to(device)
    for i in range(qm.size(0)):
        pv_inv = torch.inverse(pv[i])
        diff = pm[i] - qm[i]
        KL[i] = 0.5 * (torch.logdet(pv[i]) - torch.logdet(qv[i]) - k
                       + torch.trace(pv_inv @ qv[i]) + diff @ pv_inv @ diff)
    return KL
def conditional_sample_gaussian(m, v):
    # Reparameterized sample from a diagonal Gaussian: z = m + sqrt(v) * eps.
    sample = torch.randn(m.size()).to(m.device)
    z = m + (v ** 0.5) * sample
    return z
def gumbel_sample(m, v, temp=0.5, eps=1e-8):
    # Sample from Gumbel(0, 1); eps stabilizes the logs (v is accepted for
    # signature compatibility but is not used as a jitter term).
    u = torch.rand_like(m)
    g = -torch.log(-torch.log(u + eps) + eps)
    # Gumbel-Softmax sample over the last dim (m plays the role of logits)
    z = F.softmax((m + g) / temp, dim=-1)
    z = z.view(-1, 32 * 4)  # hardcoded for a (..., 32, 4) concept layout
    return z
def gaussian_log_prob(samples, mean, var):
    """ Returns the element-wise log probability of a Gaussian for a tensor of samples """
    return -0.5 * torch.log(var) - 0.5 * np.log(2 * np.pi) - 0.5 * ((samples - mean)**2) / var
def condition_gaussian_parameters(h, dim=1):
    # Split h into mean and pre-variance halves, then reshape to (batch, 3, 4).
    m, h = torch.split(h, h.size(dim) // 2, dim=dim)
    m = torch.reshape(m, [-1, 3, 4])
    h = torch.reshape(h, [-1, 3, 4])
    v = F.softplus(h) + 1e-8
    return m, v
def condition_prior(scale, label, dim):
    # Per-concept Gaussian prior: the mean is the label normalized by its scale,
    # broadcast across `dim`; the variance is fixed at 1.
    mean = torch.ones(label.size()[0], label.size()[1], dim)
    var = torch.ones(label.size()[0], label.size()[1], dim)
    for i in range(label.size()[0]):
        for j in range(label.size()[1]):
            mul = (float(label[i][j]) - scale[j][0]) / scale[j][1]
            mean[i][j] = torch.ones(dim) * mul
            var[i][j] = torch.ones(dim)
    return mean, var
def causal_prior(scale, label, dim, A, mask=None):
    # Prior whose mean mixes each concept with its parents in the causal graph A:
    # mean_i = ((A^T + I) @ normalized_label) / num_parents, broadcast across dim.
    mean = torch.ones(label.size()[0], label.size()[1], dim)
    var = torch.ones(label.size()[0], label.size()[1], dim)
    I = torch.eye(4).to(device)
    for i in range(label.size()[0]):
        inp = A.to(device).t() + I
        # Hardcoded parent counts for the 4-concept graph used here; the general
        # form is torch.count_nonzero(inp, dim=1).reshape((dim, 1)).
        num_parents = torch.tensor([1, 1, 3, 3]).reshape((dim, 1)).to(device)
        norm_label = (label[i].to(device) - torch.tensor(scale[:, 0]).to(device)) / (
            torch.tensor(scale[:, 1]).to(device))
        out = torch.matmul(inp.float(), norm_label.float().to(device)).reshape((dim, 1))
        fin = torch.div(out, num_parents).repeat(1, dim)
        mean[i] = fin
        var[i] = torch.ones(dim, dim)
    return mean, var
def compute_kl(z_1, z_2, logvar_1, logvar_2):
    # Element-wise KL(N(z_1, e^logvar_1) || N(z_2, e^logvar_2)).
    var_1 = torch.exp(logvar_1)
    var_2 = torch.exp(logvar_2)
    return 0.5 * (var_1 / var_2 + torch.square(z_2 - z_1) / var_2 - 1 + logvar_2 - logvar_1)
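# Sanity sketch for compute_kl (hypothetical shapes): identical means and
# log-variances give zero element-wise KL; a unit mean shift with unit
# variance gives 0.5 per element.
def _demo_compute_kl():
    z = torch.randn(8, 4)
    logvar = torch.zeros(8, 4)
    assert torch.allclose(compute_kl(z, z, logvar, logvar), torch.zeros(8, 4))
    return compute_kl(z, z + 1.0, logvar, logvar)  # ~0.5 everywhere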
def scm_prior(scale, label, dim, A, mask=None):
    # Same construction as causal_prior, but the parent counts are read off A
    # rather than hardcoded.
    mean = torch.ones(label.size()[0], label.size()[1], dim)
    var = torch.ones(label.size()[0], label.size()[1], dim)
    I = torch.eye(4).to(device)
    for i in range(label.size()[0]):
        inp = A.to(device).t() + I
        num_parents = torch.count_nonzero(inp, dim=1).reshape((dim, 1))
        norm_label = (label[i].to(device) - torch.tensor(scale[:, 0]).to(device)) / (
            torch.tensor(scale[:, 1]).to(device))
        out = torch.matmul(inp.float(), norm_label.float().to(device)).reshape((dim, 1))
        fin = torch.div(out, num_parents).repeat(1, dim)
        mean[i] = fin
        var[i] = torch.ones(dim, dim)
    return mean, var
# Conditional structural prior
def structural_condition_prior(scale, label, dim, A):
    # TODO: when intervening on the 3rd or 4th concept, hardcode its mean here.
    mean = torch.ones(label.size()[0], label.size()[1], dim)
    var = torch.ones(label.size()[0], label.size()[1], dim)
    I = torch.eye(4).to(label.device)
    for i in range(label.size()[0]):
        for j in range(label.size()[1]):
            inp = A.to(label.device).t() + I
            num_parents = torch.count_nonzero(inp, dim=1)
            norm_label = (label[i].to(label.device) - torch.tensor(scale[:, 0]).to(label.device)) / (
                torch.tensor(scale[:, 1]).to(label.device))
            mul = torch.matmul(inp.float(), norm_label.float().to(label.device))[j]
            mul = mul / num_parents[j]  # average over parents
            mean[i][j] = torch.ones(dim) * mul.item()
            if j == 3:
                # The last concept's prior mean is randomized.
                mean[i][j] = torch.randn(dim)
            var[i][j] = torch.ones(dim)
    return mean, var
def bce2(r, x):
    # Bernoulli log-likelihood of x under probabilities r (note: the positive
    # log-probability, not a loss; 1e-7 guards the logs).
    return x * torch.log(r + 1e-7) + (1 - x) * torch.log(1 - r + 1e-7)
################################################################################
# Please familiarize yourself with the code below.
#
# Note that the notation is
# argument: argument_type: argument_shape
#
# Furthermore, the expected argument_shape is only a guideline. You're free to
# pass in inputs that violate the expected argument_shape provided you know
# what you're doing
################################################################################
def sample_multivariate(cov, loc=None):
    # Build a MultivariateNormal with the given covariance; callers draw via .sample().
    if loc is None:
        loc = torch.zeros(cov.shape[:-1]).to(cov.device)
    latent_code = torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=cov)
    return latent_code
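# Usage sketch for sample_multivariate (hypothetical shapes): a batch of identity
# covariances yields one standard-normal draw per batch element.
def _demo_sample_multivariate():
    cov = torch.eye(3).expand(8, 3, 3)
    dist = sample_multivariate(cov, loc=torch.zeros(8, 3))
    z = dist.sample()
    assert z.shape == (8, 3)
    return z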
def get_covariance_matrix(A):
    # Requirement: A must be a torch tensor of square matrices (batch, d, d).
    # Returns ((A - I)^T (A - I))^-1 for each batch element.
    assert A.size()[1] == A.size()[2]
    I = torch.zeros(A.size()).to(device)
    i = torch.eye(n=A.size()[1]).to(device)
    for j in range(A.size()[0]):
        I[j] = torch.inverse(torch.mm(torch.t(A[j] - i), A[j] - i))
    return I
def sample_gaussian(m, v):
"""
Element-wise application reparameterization trick to sample from Gaussian
Args:
m: tensor: (batch, ...): Mean
v: tensor: (batch, ...): Variance
Return:
z: tensor: (batch, ...): Samples
"""
    # Sample z via the reparameterization trick: z = m + sqrt(v) * eps, eps ~ N(0, I).
    sample = torch.randn(m.shape).to(m.device)
    z = m + (v ** 0.5) * sample
    return z
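# Usage sketch for sample_gaussian (hypothetical shapes): the reparameterization
# trick keeps z differentiable with respect to the mean and variance.
def _demo_sample_gaussian():
    m = torch.zeros(8, 4, requires_grad=True)
    v = torch.ones(8, 4)
    z = sample_gaussian(m, v)
    assert z.shape == (8, 4) and z.requires_grad
    return z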
def log_normal(x, m, var):
    """
    Computes the elem-wise log probability of a Gaussian and then sums over the
    last dim. All dims are treated as batch dims except for the last.
    Note: `var` is treated as a *log*-variance here: it is summed directly for
    the log-determinant term and exponentiated in the quadratic term.
    Args:
        x: tensor: (batch, ..., dim): Observation
        m: tensor: (batch, ..., dim): Mean
        var: tensor: (batch, ..., dim): Log-variance
    Return:
        log_prob: tensor: (batch, ...): log probability of each sample. Note
            that the summation dimension (dim=-1) is not kept
    """
    const = -0.5 * x.size(-1) * torch.log(2 * torch.tensor(np.pi))
    log_det = -0.5 * torch.sum(var, dim=-1)
    log_exp = -0.5 * torch.sum((x - m)**2 / torch.exp(var), dim=-1)
    log_prob = const + log_det + log_exp
    return log_prob
def log_gaussian(x, mu, log_var):
"""
Returns the log pdf of a normal distribution parametrised
by mu and log_var evaluated at x. (Univariate distribution)
:param x: point to evaluate
:param mu: mean of distribution
:param log_var: log variance of distribution
:return: log N(x|µ,σ)
"""
log_pdf = - 0.5 * np.log(2 * np.pi) - (log_var + 1e-8) / 2 - ((x - mu)**2 + 1e-8) / (2 * torch.exp(log_var))
return torch.sum(log_pdf, dim=-1)
def log_normal_mixture(z, m, v):
"""
Computes log probability of a uniformly-weighted Gaussian mixture.
Args:
z: tensor: (batch, dim): Observations
m: tensor: (batch, mix, dim): Mixture means
    v: tensor: (batch, mix, dim): Mixture log-variances (see the note on log_normal)
Return:
log_prob: tensor: (batch,): log probability of each sample
"""
    # Score z under each mixture component, then average uniformly in
    # probability space via log-mean-exp over the mixture dimension.
    z = z.unsqueeze(1)
    log_probs = log_normal(z, m, v)
    log_prob = log_mean_exp(log_probs, 1)
return log_prob
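# Usage sketch for log_normal_mixture (hypothetical shapes): scores a batch of
# observations under a uniformly-weighted mixture; per the note on log_normal,
# the third argument is a log-variance.
def _demo_log_normal_mixture():
    z = torch.randn(8, 4)
    m = torch.zeros(8, 5, 4)     # 5 mixture components
    logv = torch.zeros(8, 5, 4)  # log-variance 0 -> unit variance
    return log_normal_mixture(z, m, logv)  # shape: (8,)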
def gaussian_parameters(h, dim=-1):
"""
Converts generic real-valued representations into mean and variance
parameters of a Gaussian distribution
Args:
h: tensor: (batch, ..., dim, ...): Arbitrary tensor
dim: int: (): Dimension along which to split the tensor for mean and
variance
    Returns:
m: tensor: (batch, ..., dim / 2, ...): Mean
v: tensor: (batch, ..., dim / 2, ...): Variance
"""
m, h = torch.split(h, h.size(dim) // 2, dim=dim)
v = F.softplus(h) + 1e-8
return m, v
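# Usage sketch for gaussian_parameters (hypothetical shapes): a network head
# that emits 2*d values per example splits into a mean and a positive variance.
def _demo_gaussian_parameters():
    h = torch.randn(8, 20)  # e.g. an encoder output with 2 * 10 units
    m, v = gaussian_parameters(h, dim=-1)
    assert m.shape == (8, 10) and v.shape == (8, 10)
    assert (v > 0).all()  # softplus + 1e-8 keeps the variance positive
    return m, v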
def log_bernoulli_with_logits(x, logits):
"""
Computes the log probability of a Bernoulli given its logits
Args:
x: tensor: (batch, dim): Observation
logits: tensor: (batch, dim): Bernoulli logits
Return:
log_prob: tensor: (batch,): log probability of each sample
"""
log_prob = -bce(input=logits, target=x).sum(-1)
return log_prob
def cross_entropy(x, logits):
    """
    Alias of log_bernoulli_with_logits: computes the log probability of a
    Bernoulli given its logits
    Args:
        x: tensor: (batch, dim): Observation
        logits: tensor: (batch, dim): Bernoulli logits
    Return:
        log_prob: tensor: (batch,): log probability of each sample
    """
    log_prob = -bce(input=logits, target=x).sum(-1)
    return log_prob
def log_bernoulli_with_logits_nosigmoid(x, logits):
    """
    Computes the log probability of a Bernoulli given probabilities: despite the
    argument name, `logits` is expected to already lie in [0, 1] and no sigmoid
    is applied (see bce2)
    Args:
        x: tensor: (batch, dim): Observation
        logits: tensor: (batch, dim): Bernoulli probabilities
    Return:
        log_prob: tensor: (batch,): log probability of each sample
    """
    log_prob = bce2(logits, x).sum(-1)
    return log_prob
def kl_cat(q, log_q, log_p):
"""
Computes the KL divergence between two categorical distributions
Args:
q: tensor: (batch, dim): Categorical distribution parameters
log_q: tensor: (batch, dim): Log of q
log_p: tensor: (batch, dim): Log of p
Return:
kl: tensor: (batch,) kl between each sample
"""
element_wise = (q * (log_q - log_p))
kl = element_wise.sum(-1)
return kl
def kl_normal(qm, qv, pm, pv):
"""
Computes the elem-wise KL divergence between two normal distributions KL(q || p) and
sum over the last dimension
Args:
qm: tensor: (batch, dim): q mean
qv: tensor: (batch, dim): q variance
pm: tensor: (batch, dim): p mean
pv: tensor: (batch, dim): p variance
Return:
kl: tensor: (batch,): kl between each sample
"""
element_wise = 0.5 * (torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1)
kl = element_wise.sum(-1)
return kl
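# Sanity sketch for kl_normal: KL(q || q) is zero, and a unit mean shift with
# unit variances gives 0.5 per dimension (2.0 per sample here with dim = 4).
def _demo_kl_normal():
    m, v = torch.zeros(8, 4), torch.ones(8, 4)
    assert torch.allclose(kl_normal(m, v, m, v), torch.zeros(8))
    return kl_normal(m + 1.0, v, m, v)  # tensor of 2.0s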
def log_prob(qm, qv, pm, pv):
    """
    Computes the elem-wise KL divergence KL(q || p) between two normal
    distributions parameterized by log-variances, then sums over the last dim
    Args:
        qm: tensor: (batch, dim): q mean
        qv: tensor: (batch, dim): q log-variance
        pm: tensor: (batch, dim): p mean
        pv: tensor: (batch, dim): p log-variance
    Return:
        kl: tensor: (batch,): kl between each sample
    """
    element_wise = 0.5 * (pv - qv + torch.exp(qv) / torch.exp(pv) + (qm - pm).pow(2) / torch.exp(pv) - 1)
    kl = element_wise.sum(-1)
    return kl
def duplicate(x, rep):
"""
Duplicates x along dim=0
Args:
x: tensor: (batch, ...): Arbitrary tensor
        rep: int: (): Number of replicates. Setting rep=1 returns original x
Returns:
_: tensor: (batch * rep, ...): Arbitrary replicated tensor
"""
return x.expand(rep, *x.shape).reshape(-1, *x.shape[1:])
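# Usage sketch for duplicate: stacks `rep` full copies of the batch along dim 0,
# e.g. to evaluate several importance samples per input.
def _demo_duplicate():
    x = torch.arange(6).reshape(3, 2)
    xr = duplicate(x, 2)
    assert xr.shape == (6, 2)
    return xr  # rows: the 3 originals, then the same 3 again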
def log_mean_exp(x, dim):
"""
Compute the log(mean(exp(x), dim)) in a numerically stable manner
Args:
x: tensor: (...): Arbitrary tensor
dim: int: (): Dimension along which mean is computed
Return:
_: tensor: (...): log(mean(exp(x), dim))
"""
return log_sum_exp(x, dim) - np.log(x.size(dim))
def log_sum_exp(x, dim=0):
"""
Compute the log(sum(exp(x), dim)) in a numerically stable manner
Args:
x: tensor: (...): Arbitrary tensor
dim: int: (): Dimension along which sum is computed
Return:
_: tensor: (...): log(sum(exp(x), dim))
"""
max_x = torch.max(x, dim)[0]
new_x = x - max_x.unsqueeze(dim).expand_as(x)
return max_x + (new_x.exp().sum(dim)).log()
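# Sanity sketch for log_sum_exp: it matches torch.logsumexp and stays finite for
# large inputs where a naive x.exp().sum(dim).log() would overflow.
def _demo_log_sum_exp():
    x = torch.tensor([[1000.0, 1000.0], [0.0, 0.0]])
    out = log_sum_exp(x, dim=1)
    assert torch.allclose(out, torch.logsumexp(x, dim=1))
    return out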
def load_model_by_name(model, global_step):
"""
Load a model based on its name model.name and the checkpoint iteration step
Args:
model: Model: (): A model
global_step: int: (): Checkpoint iteration
"""
file_path = os.path.join('checkpoints',
model.name,
'model-{:05d}.pt'.format(global_step))
print(file_path)
state = torch.load(file_path, map_location='cpu')
model.load_state_dict(state)
print("Loaded from {}".format(file_path))
################################################################################
# No need to read/understand code beyond this point. Unless you want to.
# But do you tho ¯\_(ツ)_/¯
################################################################################
def evaluate_lower_bound(model, labeled_test_subset, run_iwae=True):
    # NOTE: VAE, GMVAE, and LVAE are model classes defined elsewhere in this
    # codebase; they must be imported into this namespace for the check below.
    check_model = isinstance(model, VAE) or isinstance(model, GMVAE) or isinstance(model, LVAE)
assert check_model, "This function is only intended for VAE and GMVAE"
print('*' * 80)
print("LOG-LIKELIHOOD LOWER BOUNDS ON TEST SUBSET")
print('*' * 80)
xl, _ = labeled_test_subset
torch.manual_seed(0)
xl = torch.bernoulli(xl)
def detach_torch_tuple(args):
return (v.detach() for v in args)
def compute_metrics(fn, repeat):
metrics = [0, 0, 0]
for _ in range(repeat):
niwae, kl, rec = detach_torch_tuple(fn(xl))
metrics[0] += niwae / repeat
metrics[1] += kl / repeat
metrics[2] += rec / repeat
return metrics
# Run multiple times to get low-var estimate
nelbo, kl, rec = compute_metrics(model.negative_elbo_bound, 100)
print("NELBO: {}. KL: {}. Rec: {}".format(nelbo, kl, rec))
if run_iwae:
for iw in [1, 10, 100, 1000]:
        repeat = max(100 // iw, 1)  # fewer repeats for larger iw, but at least one pass
fn = lambda x: model.negative_iwae_bound(x, iw)
niwae, kl, rec = compute_metrics(fn, repeat)
print("Negative IWAE-{}: {}".format(iw, niwae))
def evaluate_classifier(model, test_set):
    # NOTE: SSVAE is a model class defined elsewhere in this codebase.
    check_model = isinstance(model, SSVAE)
assert check_model, "This function is only intended for SSVAE"
print('*' * 80)
print("CLASSIFICATION EVALUATION ON ENTIRE TEST SET")
print('*' * 80)
X, y = test_set
pred = model.cls.classify(X)
accuracy = (pred.argmax(1) == y).float().mean()
print("Test set classification accuracy: {}".format(accuracy))
def save_model_by_name(model, global_step):
save_dir = os.path.join('checkpoints', model.name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file_path = os.path.join(save_dir, 'model-{:05d}.pt'.format(global_step))
state = model.state_dict()
torch.save(state, file_path)
print('Saved to {}'.format(file_path))
def prepare_writer(model_name, overwrite_existing=False):
log_dir = os.path.join('logs', model_name)
save_dir = os.path.join('checkpoints', model_name)
if overwrite_existing:
delete_existing(log_dir)
delete_existing(save_dir)
# Sadly, I've been told *not* to use tensorflow :<
# writer = tf.summary.FileWriter(log_dir)
writer = None
return writer
def log_summaries(writer, summaries, global_step):
pass # Sad :<
# for tag in summaries:
# val = summaries[tag]
# tf_summary = tf.Summary.Value(tag=tag, simple_value=val)
# writer.add_summary(tf.Summary(value=[tf_summary]), global_step)
# writer.flush()
def delete_existing(path):
if os.path.exists(path):
print("Deleting existing path: {}".format(path))
shutil.rmtree(path)
def reset_weights(m):
try:
m.reset_parameters()
except AttributeError:
pass
def get_mnist_data(device, use_test_subset=True):
preprocess = transforms.ToTensor()
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=preprocess),
batch_size=100,
shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, download=True, transform=preprocess),
batch_size=100,
shuffle=True)
# Create pre-processed training and test sets
    # Use .data/.targets (the train_data/train_labels attributes are deprecated
    # in recent torchvision releases).
    X_train = train_loader.dataset.data.to(device).reshape(-1, 784).float() / 255
    y_train = train_loader.dataset.targets.to(device)
    X_test = test_loader.dataset.data.to(device).reshape(-1, 784).float() / 255
    y_test = test_loader.dataset.targets.to(device)
# Create supervised subset (deterministically chosen)
# This subset will serve dual purpose of log-likelihood evaluation and
# semi-supervised learning. Pretty hacky. Don't judge :<
X = X_test if use_test_subset else X_train
y = y_test if use_test_subset else y_train
xl, yl = [], []
for i in range(10):
idx = y == i
idx_choice = get_mnist_index(i, test=use_test_subset)
xl += [X[idx][idx_choice]]
yl += [y[idx][idx_choice]]
xl = torch.cat(xl).to(device)
yl = torch.cat(yl).to(device)
    yl = yl.new(np.eye(10)[yl.cpu().numpy()])  # one-hot encode (NumPy indexing needs a CPU array)
labeled_subset = (xl, yl)
return train_loader, labeled_subset, (X_test, y_test)
def get_mnist_index(i, test=True):
# Obviously *hand*-coded
train_idx = np.array([[2732,2607,1653,3264,4931,4859,5827,1033,4373,5874],
[5924,3468,6458,705,2599,2135,2222,2897,1701,537],
[2893,2163,5072,4851,2046,1871,2496,99,2008,755],
[797,659,3219,423,3337,2745,4735,544,714,2292],
[151,2723,3531,2930,1207,802,2176,2176,1956,3622],
[3560,756,4369,4484,1641,3114,4984,4353,4071,4009],
[2105,3942,3191,430,4187,2446,2659,1589,2956,2681],
[4180,2251,4420,4870,1071,4735,6132,5251,5068,1204],
[3918,1167,1684,3299,2767,2957,4469,560,5425,1605],
[5795,1472,3678,256,3762,5412,1954,816,2435,1634]])
test_idx = np.array([[684,559,629,192,835,763,707,359,9,723],
[277,599,1094,600,314,705,551,87,174,849],
[537,845,72,777,115,976,755,448,850,99],
[984,177,755,797,659,147,910,423,288,961],
[265,697,639,544,543,714,244,151,675,510],
[459,882,183,28,802,128,128,53,550,488],
[756,273,335,388,617,42,442,543,888,257],
[57,291,779,430,91,398,611,908,633,84],
[203,324,774,964,47,639,131,972,868,180],
[1000,846,143,660,227,954,791,719,909,373]])
if test:
return test_idx[i]
else:
return train_idx[i]
def get_svhn_data(device):
preprocess = transforms.ToTensor()
train_loader = torch.utils.data.DataLoader(
datasets.SVHN('data', split='extra', download=True, transform=preprocess),
batch_size=100,
shuffle=True)
return train_loader, (None, None), (None, None)
def gumbel_softmax(logits, tau, eps=1e-8):
U = torch.rand_like(logits)
gumbel = -torch.log(-torch.log(U + eps) + eps)
y = logits + gumbel
y = F.softmax(y / tau, dim=1)
return y
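# Sanity sketch for gumbel_softmax: each row is a relaxed one-hot sample, so the
# rows sum to one; lower tau gives sharper (more one-hot-like) samples.
def _demo_gumbel_softmax():
    y = gumbel_softmax(torch.randn(8, 10), tau=0.5)
    assert torch.allclose(y.sum(dim=1), torch.ones(8), atol=1e-5)
    return y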
class DeterministicWarmup(object):
"""
Linear deterministic warm-up as described in
[Sønderby 2016].
"""
def __init__(self, n=100, t_max=1):
self.t = 0
self.t_max = t_max
self.inc = 1/n
def __iter__(self):
return self
def __next__(self):
t = self.t + self.inc
self.t = self.t_max if t > self.t_max else t
return self.t
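# Usage sketch for DeterministicWarmup: a KL weight that ramps linearly from
# 1/n to t_max and then stays there.
def _demo_deterministic_warmup():
    warmup = DeterministicWarmup(n=4, t_max=1)
    betas = [next(warmup) for _ in range(6)]  # 0.25, 0.5, 0.75, 1.0, 1, 1
    assert betas[-1] == 1
    return betas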
class FixedSeed:
def __init__(self, seed):
self.seed = seed
self.state = None
def __enter__(self):
self.state = np.random.get_state()
np.random.seed(self.seed)
def __exit__(self, exc_type, exc_value, traceback):
np.random.set_state(self.state)
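# Usage sketch for FixedSeed: NumPy draws inside the block are reproducible,
# and the surrounding RNG state is restored on exit.
def _demo_fixed_seed():
    with FixedSeed(0):
        a = np.random.rand(3)
    with FixedSeed(0):
        b = np.random.rand(3)
    assert np.allclose(a, b)
    return a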