zirobtc committed on
Commit
0b85347
·
verified ·
1 Parent(s): 6eb0e3f

Delete models/diffloss.py

Browse files
Files changed (1) hide show
  1. models/diffloss.py +0 -258
models/diffloss.py DELETED
@@ -1,258 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.checkpoint import checkpoint
4
- import math
5
- from timm.layers.mlp import SwiGLU
6
- from models.diffusion import create_diffusion
7
-
8
-
9
class DiffLoss(nn.Module):
    """Diffusion Loss.

    Owns a ``SimpleMLPAdaLN`` denoising network plus two diffusion objects
    built by ``create_diffusion`` (both with the cosine noise schedule): one
    used by :meth:`forward` for training and one, created with
    ``timestep_respacing=num_sampling_steps``, used by :meth:`sample`.
    """

    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps,
                 grad_checkpointing=False, learn_sigma=False):
        """
        :param target_channels: channels of the tensors being denoised.
        :param z_channels: channels of the conditioning tensor ``z``.
        :param depth: number of residual blocks in the denoising MLP.
        :param width: hidden width of the denoising MLP.
        :param num_sampling_steps: forwarded as ``timestep_respacing`` to the
            diffusion object used for sampling.
        :param grad_checkpointing: forwarded to the denoising MLP.
        :param learn_sigma: when true, the MLP outputs ``2 * target_channels``
            channels instead of ``target_channels``.
        """
        super().__init__()
        self.in_channels = target_channels
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2 if learn_sigma else target_channels,
            z_channels=z_channels,
            num_res_blocks=depth,
            grad_checkpointing=grad_checkpointing,
        )

        self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
        self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")

    def forward(self, target, z, mask=None):
        """Compute the diffusion training loss for ``target`` conditioned on ``z``.

        :param target: tensor to denoise; its first dimension is the batch.
        :param z: conditioning, passed to the network as keyword ``c``.
        :param mask: optional per-element weights multiplied into the loss.
            NOTE(review): assumed non-negative (e.g. 0/1) -- confirm with callers.
        :return: ``(loss, pred_xstart)`` where ``loss`` is a scalar. With a
            mask, the loss is the mask-weighted mean; an all-zero mask yields
            ``0`` rather than ``NaN``.
        """
        # One uniformly random timestep per batch element.
        t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
        model_kwargs = dict(c=z)
        loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
        loss = loss_dict["loss"]
        pred_xstart = loss_dict["pred_xstart"]
        if mask is not None:
            weighted_sum = (loss * mask).sum()
            mask_total = mask.sum().to(weighted_sum.dtype)
            # Clamp the denominator so an empty mask gives 0 / tiny == 0
            # instead of 0 / 0 == NaN (which would poison the optimizer step).
            loss = weighted_sum / mask_total.clamp(min=torch.finfo(weighted_sum.dtype).tiny)
        # After the masked reduction ``loss`` is already a scalar, so the
        # mean below only matters on the unmasked path.
        return loss.mean(), pred_xstart

    def sample(self, z, temperature=1.0, cfg=1.0):
        """Draw samples from the sampling diffusion process conditioned on ``z``.

        :param z: conditioning, passed to the network as keyword ``c``.
        :param temperature: forwarded to ``p_sample_loop``.
        :param cfg: guidance scale; any value other than ``1.0`` routes the
            network call through ``forward_with_cfg``.
        :return: whatever ``p_sample_loop`` returns for the initial noise.
        """
        if cfg != 1.0:
            # Draw noise for half of the batch and duplicate it so both halves
            # start from the same noise.
            # NOTE(review): presumably ``z`` stacks the conditional and
            # unconditional halves along dim 0 -- confirm with the caller.
            noise = torch.randn(z.shape[0] // 2, self.in_channels).to(z.device)
            noise = torch.cat([noise, noise], dim=0)
            model_kwargs = dict(c=z, cfg_scale=cfg)
            sample_fn = self.net.forward_with_cfg
        else:
            noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
            model_kwargs = dict(c=z)
            sample_fn = self.net.forward

        sampled_token_latent = self.gen_diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
            temperature=temperature
        )

        return sampled_token_latent
54
-
55
-
56
def modulate(x, shift, scale):
    """Scale ``x`` by ``1 + scale`` and then add ``shift``."""
    gain = 1 + scale
    return x * gain + shift
58
-
59
-
60
class TimestepEmbedder(nn.Module):
    """
    Turns a batch of scalar timesteps into vectors: a fixed sinusoidal
    encoding followed by a two-layer MLP with a SiLU in between.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for a batch of timesteps.
        :param t: a 1-D Tensor of N indices, one per batch element
                  (fractional values are allowed).
        :param dim: width of the returned embedding.
        :param max_period: sets the lowest frequency used.
        :return: an (N, dim) Tensor laid out as [cos | sin | zero-pad if dim is odd].
        """
        n_freqs = dim // 2
        # Geometric frequency ladder from 1 down towards 1 / max_period.
        exponents = -math.log(max_period) * torch.arange(start=0, end=n_freqs, dtype=torch.float32) / n_freqs
        freqs = torch.exp(exponents).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        halves = [torch.cos(phases), torch.sin(phases)]
        embedding = torch.cat(halves, dim=-1)
        if dim % 2 != 0:
            # cos/sin only fill an even width; pad one zero column.
            zero_col = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_col], dim=-1)
        return embedding

    def forward(self, t):
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoid)
98
-
99
-
100
class ResBlock(nn.Module):
    """
    A gated residual MLP block conditioned through adaptive LayerNorm.
    Input and output both have ``channels`` features.
    :param channels: the number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)

        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        # Produces shift, scale and gate (hence 3 * channels) from the condition.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x, y):
        """
        :param x: features to transform.
        :param y: conditioning vector fed to the modulation branch.
        :return: ``x`` plus the gated MLP output.
        """
        shift, scale, gate = torch.chunk(self.adaLN_modulation(y), 3, dim=-1)
        normed = self.in_ln(x)
        # Same affine modulation as the module-level ``modulate`` helper.
        hidden = self.mlp(normed * (1 + scale) + shift)
        return x + gate * hidden
132
-
133
-
134
class FinalLayer(nn.Module):
    """
    Output head adopted from DiT: a non-affine LayerNorm, a
    condition-driven shift/scale, then a linear projection.
    """
    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)

        # Produces shift and scale (hence 2 * model_channels) from the condition.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True),
        )

    def forward(self, x, c):
        """
        :param x: features with ``model_channels`` channels.
        :param c: conditioning vector fed to the modulation branch.
        :return: projected features with ``out_channels`` channels.
        """
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=-1)
        # Same affine modulation as the module-level ``modulate`` helper.
        modulated = self.norm_final(x) * (1 + scale) + shift
        return self.linear(modulated)
153
-
154
-
155
class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: hidden width shared by every layer of the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks stacked in sequence.
    :param grad_checkpointing: if true, run each residual block under
        activation checkpointing (skipped while TorchScript-scripting).
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)]
        )
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise all weights; the order below matters because later
        steps overwrite the generic Xavier initialisation."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers; with a zero gate each residual
        # block initially returns its input unchanged.
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out the output head so the network initially outputs zeros.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x out_channels] Tensor of outputs.
        """
        x = x.float()

        x = self.input_proj(x)
        t = self.time_embed(t)
        c = self.cond_embed(c)

        # Timestep and condition embeddings are summed into one modulation vector.
        y = t + c

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.res_blocks:
                # Pass use_reentrant explicitly: PyTorch warns when it is
                # omitted, and the non-reentrant variant is the recommended one.
                x = checkpoint(block, x, y, use_reentrant=False)
        else:
            for block in self.res_blocks:
                x = block(x, y)

        return self.final_layer(x, y)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """
        Forward pass that applies classifier-free guidance to the output.
        Only the first half of ``x`` is used; it is duplicated so both halves
        of the batch see the same input. The first ``in_channels`` output
        columns are split into a first ("cond") and second ("uncond") half
        along the batch, blended as ``uncond + cfg_scale * (cond - uncond)``,
        and the blend is written to both halves. Any remaining output columns
        are returned unchanged.
        NOTE(review): presumably ``c`` stacks the conditional and then the
        unconditional conditioning along dim 0 -- confirm with the caller.
        :param x: an [N x C] Tensor of inputs; N is expected to be even.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning for the full batch.
        :param cfg_scale: guidance scale.
        :return: an [N x out_channels] Tensor of outputs.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)