File size: 6,167 Bytes
6e7d4ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from functools import reduce
import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_add

from src.utils import bvm


class LinearSchedule:
    r"""
    Linear noise-removal schedule.

    The cumulative scheduling parameter decays linearly in t:
        \bar{\beta}_t = 1 - t
    giving the cumulative transition matrix
        \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T

    From this, it follows that for each per-step transition matrix we have
        \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = (1 - t) / (1 - t + h)
    where h is the step size.
    """

    def beta_bar(self, t):
        r"""Cumulative coefficient \bar{\beta}_t = 1 - t (t is typically in [0, 1])."""
        return 1 - t

    def beta(self, t, step_size):
        r"""Per-step coefficient \beta_t for a transition of size `step_size` ending at t."""
        return (1 - t) / (1 - t + step_size)


class UniformPriorMarkovBridge:
    r"""
    Markov bridge model in which z0 is drawn from a uniform prior.

    Transitions are defined as:
        Q_t = \beta_t I + (1 - \beta_t) 1_vec z1^T
    where z1 is a one-hot representation of the final state.
    We follow the notation from [1] and multiply transition matrices from the
    right to one-hot state vectors.

    We use the scheduling parameter \beta to linearly remove noise, i.e.
        \bar{\beta}_t = 1 - t
    with
        \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T

    From this, it follows that for each per-step transition matrix, we have
        \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = (1 - t) / (1 - t + h)

    [1] Austin, Jacob, et al.
    "Structured denoising diffusion models in discrete state-spaces."
    Advances in Neural Information Processing Systems 34 (2021): 17981-17993.
    """
    def __init__(self, dim, loss_type='CE', step_size=None):
        """
        :param dim: number of discrete categories
        :param loss_type: 'CE' (cross-entropy against z1) or 'VLB'
            (KL divergence between true and predicted one-step posteriors)
        :param step_size: sampling step size; only required for the VLB loss
        """
        assert loss_type in ['VLB', 'CE']
        self.dim = dim
        self.step_size = step_size  # required for VLB
        self.schedule = LinearSchedule()
        self.loss_type = loss_type
        super().__init__()

    @staticmethod
    def sample_categorical(p):
        """
        Sample from categorical distribution defined by probabilities 'p'
        :param p: (n, dim)
        :return: one-hot encoded samples (n, dim)
        """
        sampled = torch.multinomial(p, 1).squeeze(-1)
        return F.one_hot(sampled, num_classes=p.size(1)).float()

    def p_z0(self, batch_mask):
        """Uniform prior probabilities over the `dim` categories. (n, dim)"""
        return torch.ones((len(batch_mask), self.dim), device=batch_mask.device) / self.dim

    def sample_z0(self, batch_mask):
        """ Sample one-hot states from the prior. (n, dim) """
        z0 = self.sample_categorical(self.p_z0(batch_mask))
        return z0

    def p_zt(self, z0, z1, t, batch_mask):
        """Marginal p(z_t | z_0, z_1) obtained via the cumulative transition matrix."""
        Qt_bar = self.get_Qt_bar(t, z1, batch_mask)
        return bvm(z0, Qt_bar)

    def sample_zt(self, z0, z1, t, batch_mask):
        """Sample one-hot z_t from p(z_t | z_0, z_1)."""
        zt = self.sample_categorical(self.p_zt(z0, z1, t, batch_mask))
        return zt

    def p_zt_given_zs_and_z1(self, zs, z1, s, t, batch_mask):
        """
        One-step posterior p(z_t | z_s, z_1).
        'z1' are one-hot "probabilities" for each class.
        """
        Qt = self.get_Qt(t, s, z1, batch_mask)
        q_zs_given_zt = bvm(zs, Qt)
        return q_zs_given_zt

    def p_zt_given_zs(self, zs, p_z1_hat, s, t, batch_mask):
        r"""
        Note that 'p_z1_hat' can also represent a categorical distribution to
        compute transitions more efficiently at sampling time:
        p(z_t|z_s) = \sum_{\hat{z}_1} p(z_t | z_s, \hat{z}_1) * p(\hat{z}_1 | z_s)
                   = \sum_i z_s (\beta_t I + (1 - \beta_t) 1_vec z1_i^T) * \hat{p}_i
                   = \beta_t z_s I + (1 - \beta_t) z_s 1_vec \hat{p}^t
        """
        return self.p_zt_given_zs_and_z1(zs, p_z1_hat, s, t, batch_mask)

    def sample_zt_given_zs(self, zs, z1_logits, s, t, batch_mask):
        """Sample one-hot z_t from the model's predicted one-step transition."""
        p_z1 = z1_logits.softmax(dim=-1)
        zt = self.sample_categorical(self.p_zt_given_zs(zs, p_z1, s, t, batch_mask))
        return zt

    def compute_loss(self, pred_logits, zs, z1, batch_mask, s, t, reduce='mean'):
        """
        Compute loss per sample.

        :param reduce: 'mean'/'sum' reduce per graph via batch_mask; 'none'
            returns the per-node loss unreduced.
        """
        assert reduce in {'mean', 'sum', 'none'}

        if self.loss_type == 'CE':
            loss = F.cross_entropy(pred_logits, z1, reduction='none')

        else:  # VLB
            true_p_zs = self.p_zt_given_zs_and_z1(zs, z1, s, t, batch_mask)
            pred_p_zs = self.p_zt_given_zs(zs, pred_logits.softmax(dim=-1), s, t, batch_mask)
            # F.kl_div expects log-probabilities as input and probabilities as
            # target; summing over the category dimension yields KL per node
            loss = F.kl_div(pred_p_zs.log(), true_p_zs, reduction='none').sum(dim=-1)

        if reduce == 'mean':
            loss = scatter_mean(loss, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss

    def get_Qt(self, t, s, z1, batch_mask):
        """ Returns one-step transition matrix from step s to step t. """

        # per-step coefficient with step size (t - s), expanded per node
        beta_t_given_s = self.schedule.beta(t, t - s)
        beta_t_given_s = beta_t_given_s.unsqueeze(-1)[batch_mask]

        # Q_t = beta_t * I + (1 - beta_t) * ones (dot) z1^T
        Qt = beta_t_given_s * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
             (1 - beta_t_given_s) * z1.unsqueeze(1)

        return Qt

    def get_Qt_bar(self, t, z1, batch_mask):
        """ Returns transition matrix from step 0 to step t. """

        beta_bar_t = self.schedule.beta_bar(t)
        beta_bar_t = beta_bar_t.unsqueeze(-1)[batch_mask]

        # Q_t_bar = beta_bar * I + (1 - beta_bar) * ones (dot) z1^T
        Qt_bar = beta_bar_t * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
                 (1 - beta_bar_t) * z1.unsqueeze(1)

        return Qt_bar


class MarginalPriorMarkovBridge(UniformPriorMarkovBridge):
    """
    Markov bridge whose prior p(z0) is a fixed marginal distribution
    `prior_p` over the `dim` categories, instead of the uniform prior.
    """
    def __init__(self, dim, prior_p, loss_type='CE', step_size=None):
        """
        :param prior_p: (dim,) tensor of prior class probabilities
        """
        self.prior_p = prior_p
        super(MarginalPriorMarkovBridge, self).__init__(dim, loss_type, step_size)

    def p_z0(self, batch_mask):
        """Prior probabilities: one copy of `prior_p` per node. (n, dim)"""
        device = batch_mask.device
        # broadcast the (dim,) marginal to every element of the batch
        return self.prior_p.to(device).view(1, -1).repeat(len(batch_mask), 1)