import torch
from abc import ABC
from torch import nn
from torch.nn import functional as F
from .diffusion import DiffusionLoss, DDIMSampler, DDPMSampler
from .transformer import TransformerModel
from .mamba import MambaModel
from .lstm import LstmModel
from .gatemlp import GMLPModel




class ModelDiffusion(nn.Module, ABC):
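    """Base class for diffusion-based parameter generation.

    ``config`` must be set on the class before instantiation; it is shared
    with ``DiffusionLoss`` and the backbone model. Subclasses assign the
    sequence backbone to ``self.model`` after calling ``super().__init__``.
    """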
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        DiffusionLoss.config = self.config
        self.criteria = DiffusionLoss()
        if self.config.get("post_d_model") is None:
            assert self.config["d_model"] == self.config["condition_dim"]
        self.sequence_length = sequence_length
        # subclasses define self.model after this constructor runs
        self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        # initialize the permutation-state embeddings to a uniform constant
        self.to_permutation_state.weight = \
                nn.Parameter(torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"])

    def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
        # condition
        if condition is not None:
            assert len(condition.shape) == 2
            assert condition.shape[-1] == self.config["d_condition"]
            condition = self.to_condition(condition.to(self.device)[:, None, :])
        else:  # no external condition: use a zero placeholder (valid when d_condition == 1)
            condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))
        # process
        if kwargs.get("sample"):
            if permutation_state is not False:
                permutation_state = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:  # permutation state == False
                permutation_state = 0.
            return self.sample(x=None, condition=condition+permutation_state)
        else:  # train
            if permutation_state is not None:
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:  # not use permutation state
                permutation_state = 0.
            # Given condition c and ground truth token x, compute loss
            c = self.model(output_shape, condition+permutation_state)
            loss = self.criteria(x=x_0, c=c, **kwargs)
            return loss

    @torch.no_grad()
    def sample(self, x=None, condition=None):
        # backbone features z act as the per-token condition for the diffusion sampler
        z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
        if x is None:
            # start from Gaussian noise in the token space
            x = torch.randn((1, self.sequence_length, self.config["model_dim"]), device=z.device)
        x = self.criteria.sample(x, z)
        return x

    @property
    def device(self):
        return next(self.parameters()).device


class ModelMSELoss(nn.Module, ABC):
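    """MSE-regression counterpart of ModelDiffusion.

    Uses the same conditioning / permutation-state interface, but trains the
    backbone output directly against the ground-truth tokens with an MSE
    loss; NaN positions in ``x_0`` are excluded from the mean.
    """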
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        if self.config.get("post_d_model") is None:
            assert self.config["d_model"] == self.config["condition_dim"]
        self.sequence_length = sequence_length
        # subclasses define self.model after this constructor runs
        self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        # initialize the permutation-state embeddings to a uniform constant
        self.to_permutation_state.weight = \
                nn.Parameter(torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"])

    def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
        # condition
        if condition is not None:
            assert len(condition.shape) == 2
            assert condition.shape[-1] == self.config["d_condition"]
            condition = self.to_condition(condition.to(self.device)[:, None, :])
        else:  # no external condition: use a zero placeholder (valid when d_condition == 1)
            condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))
        # process
        if kwargs.get("sample"):
            if permutation_state is not False:
                permutation_state = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:  # permutation state == False
                permutation_state = 0.
            return self.sample(x=None, condition=condition+permutation_state)
        else:  # train
            if permutation_state is not None:
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:  # not use permutation state
                permutation_state = 0.
            # Given condition c and ground truth token x, compute loss
            c = self.model(output_shape, condition+permutation_state)
            assert c.shape[-1] == x_0.shape[-1], "d_model should be equal to dim_per_token"
            # record NaN positions in the target and replace them with zeros
            mask = torch.isnan(x_0)
            x_0 = torch.nan_to_num(x_0, 0.)
            # element-wise MSE, averaged over the non-NaN positions only
            loss = F.mse_loss(c, x_0, reduction="none")
            loss[mask] = torch.nan
            return loss.nanmean()

    @torch.no_grad()
    def sample(self, x=None, condition=None):
        z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
        return z

    @property
    def device(self):
        return next(self.parameters()).device




class MambaDiffusion(ModelDiffusion):
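    """Diffusion generator with a Mamba backbone; the wrappers below follow the same pattern."""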
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        MambaModel.config = self.config
        self.model = MambaModel(positional_embedding=positional_embedding)


class TransformerDiffusion(ModelDiffusion):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        TransformerModel.config = self.config
        self.model = TransformerModel(positional_embedding=positional_embedding)


class LstmDiffusion(ModelDiffusion):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        LstmModel.config = self.config
        self.model = LstmModel(positional_embedding=positional_embedding)


class GMLPDiffusion(ModelDiffusion):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        GMLPModel.config = self.config
        self.model = GMLPModel(positional_embedding=positional_embedding)




class MambaMSELoss(ModelMSELoss):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        MambaModel.config = self.config
        self.model = MambaModel(positional_embedding=positional_embedding)




class ClassConditionMambaDiffusion(MambaDiffusion):
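    """Class-conditional MambaDiffusion.

    The class vector is projected to ``d_condition`` and then to ``d_model``;
    a fixed gate restricts the condition signal to the last 8 sequence
    positions.
    """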
    def __init__(self, sequence_length, positional_embedding, input_class=10):
        super().__init__(sequence_length, positional_embedding)
        self.get_condition = nn.Sequential(
            nn.Linear(input_class, self.config["d_condition"]),
            nn.SiLU(),
        )  # to condition
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        # condition module
        self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
        to_condition_gate = torch.zeros(size=(1, sequence_length, 1))
        to_condition_gate[:, -8:, :] = 1.
        self.register_buffer("to_condition_gate", to_condition_gate)
        # replace the parent's linear to_condition with the gated mapping below
        del self.to_condition
        self.to_condition = self._to_condition

    def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
        condition = self.get_condition(condition.to(self.device))
        return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)

    def _to_condition(self, x):
        assert len(x.shape) == 3
        x = self.to_condition_linear(x)
        x = x * self.to_condition_gate
        return x


class ClassConditionMambaDiffusionFull(MambaDiffusion):
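    """Class-conditional MambaDiffusion that injects the condition at every position.

    The class embedding is projected to ``d_model`` and expanded over the
    full sequence by a small 1-D conv stack; when ``pre_training=True`` is
    passed to ``forward``, the condition is replaced by zeros.
    """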
    def __init__(self, sequence_length, positional_embedding, input_class=10, init_noise_intensity=1e-4):
        super().__init__(sequence_length, positional_embedding)
        self.get_condition = nn.Sequential(
            nn.Linear(input_class, self.config["d_condition"]),
            nn.LayerNorm(self.config["d_condition"]),
        )  # to condition
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        # condition module
        self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
        self.to_condition_conv = nn.Sequential(
            nn.Conv1d(1, sequence_length, 9, 1, 4),
            nn.GroupNorm(num_groups=1, num_channels=sequence_length),
            nn.Conv1d(sequence_length, sequence_length, 9, 1, 4),
        )  # [batch_size, sequence_length, d_model]
        # remove the parent's to_condition; forward() re-binds it on every call
        del self.to_condition

    def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
        if kwargs.get("pre_training"):
            self.to_condition = self._zero_condition
            condition = None
        else:  # train with condition
            self.to_condition = self._to_condition
            condition = self.get_condition(condition.to(self.device))
        return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)

    def _to_condition(self, x):
        assert len(x.shape) == 3
        x = self.to_condition_linear(x)  # [batch_size, 1, d_model]
        x = self.to_condition_conv(x)  # [batch_size, sequence_length, d_model]
        return x

    def _zero_condition(self, x):
        return torch.zeros(size=(x.shape[0], self.sequence_length, self.config["d_model"]), device=x.device)
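

# Minimal usage sketch (illustrative only; the concrete config dict, sequence
# length, and positional embedding come from the experiment configs elsewhere
# in this repo, so every name below is a placeholder):
#
#   MambaDiffusion.config = my_config          # shared with MambaModel and DiffusionLoss
#   model = MambaDiffusion(sequence_length=my_sequence_length,
#                          positional_embedding=my_positional_embedding)
#   loss = model(output_shape, x_0, permutation_state=perm_idx)   # training step
#   tokens = model(sample=True)                                   # generation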