"""Markov bridge models over categorical states with uniform or marginal priors."""
from functools import reduce
import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_add
from src.utils import bvm
class LinearSchedule:
    r"""Linear noise-removal schedule.

    The cumulative scheduling parameter decays linearly in t:
        \bar{\beta}_t = 1 - t
    so the per-step coefficient for a transition of width h ending at t is
        \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}
    """

    def beta_bar(self, t):
        r"""Cumulative coefficient \bar{\beta}_t = 1 - t."""
        return 1 - t

    def beta(self, t, step_size):
        r"""Per-step coefficient \beta_t for a step of size `step_size` ending at t."""
        remaining = 1 - t
        return remaining / (remaining + step_size)
class UniformPriorMarkovBridge:
"""
Markov bridge model in which z0 is drawn from a uniform prior.
Transitions are defined as:
Q_t = \beta_t I + (1 - \beta_t) 1_vec z1^T
where z1 is a one-hot representation of the final state.
We follow the notation from [1] and multiply transition matrices from the
right to one-hot state vectors.
We use the scheduling parameter \beta to linearly remove noise, i.e.
\bar{\beta}_t = 1 - h (h: step size) with
\bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
From this, it follows that for each step transition matrix, we have
\beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}
[1] Austin, Jacob, et al.
"Structured denoising diffusion models in discrete state-spaces."
Advances in Neural Information Processing Systems 34 (2021): 17981-17993.
"""
def __init__(self, dim, loss_type='CE', step_size=None):
assert loss_type in ['VLB', 'CE']
self.dim = dim
self.step_size = step_size # required for VLB
self.schedule = LinearSchedule()
self.loss_type = loss_type
super(UniformPriorMarkovBridge, self).__init__()
@staticmethod
def sample_categorical(p):
"""
Sample from categorical distribution defined by probabilities 'p'
:param p: (n, dim)
:return: one-hot encoded samples (n, dim)
"""
sampled = torch.multinomial(p, 1).squeeze(-1)
return F.one_hot(sampled, num_classes=p.size(1)).float()
def p_z0(self, batch_mask):
return torch.ones((len(batch_mask), self.dim), device=batch_mask.device) / self.dim
def sample_z0(self, batch_mask):
""" Prior. """
z0 = self.sample_categorical(self.p_z0(batch_mask))
return z0
def p_zt(self, z0, z1, t, batch_mask):
Qt_bar = self.get_Qt_bar(t, z1, batch_mask)
return bvm(z0, Qt_bar)
def sample_zt(self, z0, z1, t, batch_mask):
zt = self.sample_categorical(self.p_zt(z0, z1, t, batch_mask))
return zt
def p_zt_given_zs_and_z1(self, zs, z1, s, t, batch_mask):
# 'z1' are one-hot "probabilities" for each class
Qt = self.get_Qt(t, s, z1, batch_mask)
# from pdb import set_trace; set_trace()
q_zs_given_zt = bvm(zs, Qt)
return q_zs_given_zt
def p_zt_given_zs(self, zs, p_z1_hat, s, t, batch_mask):
"""
Note that x can also represent a categorical distribution to compute
transitions more efficiently at sampling time:
p(z_t|z_s) = \sum_{\hat{z}_1} p(z_t | z_s, \hat{z}_1) * p(\hat{z}_1 | z_s)
= \sum_i z_s (\beta_t I + (1 - \beta_t) 1_vec z1_i^T) * \hat{p}_i
= \beta_t z_s I + (1 - \beta_t) z_s 1_vec \hat{p}^t
"""
return self.p_zt_given_zs_and_z1(zs, p_z1_hat, s, t, batch_mask)
def sample_zt_given_zs(self, zs, z1_logits, s, t, batch_mask):
p_z1 = z1_logits.softmax(dim=-1)
zt = self.sample_categorical(self.p_zt_given_zs(zs, p_z1, s, t, batch_mask))
return zt
def compute_loss(self, pred_logits, zs, z1, batch_mask, s, t, reduce='mean'):
""" Compute loss per sample. """
assert reduce in {'mean', 'sum', 'none'}
if self.loss_type == 'CE':
loss = F.cross_entropy(pred_logits, z1, reduction='none')
else: # VLB
true_p_zs = self.p_zt_given_zs_and_z1(zs, z1, s, t, batch_mask)
pred_p_zs = self.p_zt_given_zs(zs, pred_logits.softmax(dim=-1), s, t, batch_mask)
loss = F.kl_div(pred_p_zs.log(), true_p_zs, reduction='none').sum(dim=-1)
if reduce == 'mean':
loss = scatter_mean(loss, batch_mask, dim=0)
elif reduce == 'sum':
loss = scatter_add(loss, batch_mask, dim=0)
return loss
def get_Qt(self, t, s, z1, batch_mask):
""" Returns one-step transition matrix from step s to step t. """
beta_t_given_s = self.schedule.beta(t, t - s)
beta_t_given_s = beta_t_given_s.unsqueeze(-1)[batch_mask]
# Q_t = beta_t * I + (1 - beta_t) * ones (dot) z1^T
Qt = beta_t_given_s * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
(1 - beta_t_given_s) * z1.unsqueeze(1)
# (1 - beta_t_given_s) * (torch.ones(self.dim, 1, device=t.device) @ z1)
# assert (Qt.sum(-1) == 1).all()
return Qt
def get_Qt_bar(self, t, z1, batch_mask):
""" Returns transition matrix from step 0 to step t. """
beta_bar_t = self.schedule.beta_bar(t)
beta_bar_t = beta_bar_t.unsqueeze(-1)[batch_mask]
# Q_t_bar = beta_bar * I + (1 - beta_bar) * ones (dot) z1^T
Qt_bar = beta_bar_t * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
(1 - beta_bar_t) * z1.unsqueeze(1)
# (1 - beta_bar_t) * (torch.ones(self.dim, 1, device=t.device) @ z1)
# assert (Qt_bar.sum(-1) == 1).all()
return Qt_bar
class MarginalPriorMarkovBridge(UniformPriorMarkovBridge):
    """Markov bridge whose z0 prior follows fixed marginal class probabilities.

    Identical to the uniform-prior bridge except that the starting state is
    drawn from `prior_p` rather than uniformly over classes.
    """

    def __init__(self, dim, prior_p, loss_type='CE', step_size=None):
        self.prior_p = prior_p  # per-class marginal probabilities of the prior
        print('Marginal Prior MB')
        super(MarginalPriorMarkovBridge, self).__init__(dim, loss_type, step_size)

    def p_z0(self, batch_mask):
        """Per-element prior: every row is the marginal distribution `prior_p`."""
        dev = batch_mask.device
        base = torch.ones((len(batch_mask), self.dim), device=dev)
        return base * self.prior_p.view(1, -1).to(dev)