ak36 committed on
Commit
ec85c56
·
verified ·
1 Parent(s): ac3c1fa

Upload folder using huggingface_hub

Browse files
Modules/diffusion/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Modules/diffusion/diffusion.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ from random import randint
3
+ from typing import Any, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ from tqdm import tqdm
9
+
10
+ from .utils import *
11
+ from .sampler import *
12
+
13
+ """
14
+ Diffusion Classes (generic for 1d data)
15
+ """
16
+
17
+
18
class Model1d(nn.Module):
    """Thin wrapper that delegates training and sampling to ``self.diffusion``.

    NOTE(review): the constructor leaves both ``self.unet`` and
    ``self.diffusion`` as ``None``; ``forward`` and ``sample`` therefore only
    work once something else has assigned a diffusion object.
    """

    def __init__(self, unet_type: str = "base", **kwargs):
        super().__init__()
        # Presumably separates the "diffusion_"-prefixed options from the rest
        # (helper lives in .utils - confirm). Neither group is consumed here.
        _diffusion_options, _remaining_options = groupby("diffusion_", kwargs)
        self.unet = None
        self.diffusion = None

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Call ``self.diffusion`` on ``x`` and return its result."""
        result = self.diffusion(x, **kwargs)
        return result

    def sample(self, *args, **kwargs) -> Tensor:
        """Forward all arguments to ``self.diffusion.sample``."""
        generated = self.diffusion.sample(*args, **kwargs)
        return generated
30
+
31
+
32
+ """
33
+ Audio Diffusion Classes (specific for 1d audio data)
34
+ """
35
+
36
+
37
def get_default_model_kwargs():
    """Build and return a new dict holding the default model keyword arguments.

    A fresh dict (and a fresh ``UniformDistribution``) is created on every
    call, so callers may mutate the result freely.
    """
    defaults = {
        "channels": 128,
        "patch_size": 16,
        "multipliers": [1, 2, 4, 4, 4, 4, 4],
        "factors": [4, 4, 4, 2, 2, 2],
        "num_blocks": [2, 2, 2, 2, 2, 2],
        "attentions": [0, 0, 0, 1, 1, 1, 1],
        "attention_heads": 8,
        "attention_features": 64,
        "attention_multiplier": 2,
        "attention_use_rel_pos": False,
        "diffusion_type": "v",
        "diffusion_sigma_distribution": UniformDistribution(),
    }
    return defaults
52
+
53
+
54
def get_default_sampling_kwargs():
    """Build and return a new dict holding the default sampling keyword arguments."""
    defaults = {
        "sigma_schedule": LinearSchedule(),
        "sampler": VSampler(),
        "clamp": True,
    }
    return defaults
56
+
57
+
58
class AudioDiffusionModel(Model1d):
    """``Model1d`` preconfigured with the audio defaults; caller kwargs win."""

    def __init__(self, **kwargs):
        # Start from the defaults, then let explicit arguments override them.
        merged = get_default_model_kwargs()
        merged.update(kwargs)
        super().__init__(**merged)

    def sample(self, *args, **kwargs):
        """Sample with the default sampling kwargs unless overridden."""
        merged = get_default_sampling_kwargs()
        merged.update(kwargs)
        return super().sample(*args, **merged)
64
+
65
+
66
class AudioDiffusionConditional(Model1d):
    """``Model1d`` configured with ``unet_type="cfg"`` and embedding defaults.

    Args:
        embedding_features: forwarded as ``context_embedding_features``.
        embedding_max_length: forwarded as ``context_embedding_max_length``.
        embedding_mask_proba: value passed as ``embedding_mask_proba`` to
            ``forward`` whenever the caller does not provide one.
        **kwargs: overrides for any of the default model keyword arguments.
    """

    def __init__(
        self,
        embedding_features: int,
        embedding_max_length: int,
        embedding_mask_proba: float = 0.1,
        **kwargs,
    ):
        default_kwargs = dict(
            **get_default_model_kwargs(),
            unet_type="cfg",
            context_embedding_features=embedding_features,
            context_embedding_max_length=embedding_max_length,
        )
        super().__init__(**{**default_kwargs, **kwargs})
        # Assigned only after nn.Module.__init__ has run: PyTorch requires the
        # parent constructor to be called before attributes are set on the
        # child module.
        self.embedding_mask_proba = embedding_mask_proba

    def forward(self, *args, **kwargs):
        """Run the parent forward, defaulting ``embedding_mask_proba``."""
        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
        return super().forward(*args, **{**default_kwargs, **kwargs})

    def sample(self, *args, **kwargs):
        """Sample with the default sampling kwargs plus ``embedding_scale=5.0``."""
        default_kwargs = dict(
            **get_default_sampling_kwargs(),
            embedding_scale=5.0,
        )
        return super().sample(*args, **{**default_kwargs, **kwargs})
93
+
94
+
Modules/diffusion/modules.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import floor, log, pi
2
+ from typing import Any, List, Optional, Sequence, Tuple, Union
3
+
4
+ from .utils import *
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, reduce, repeat
9
+ from einops.layers.torch import Rearrange
10
+ from einops_exts import rearrange_many
11
+ from torch import Tensor, einsum
12
+
13
+
14
+ """
15
+ Utils
16
+ """
17
+
18
class AdaLayerNorm(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a style vector.

    Args:
        style_dim: size of the style vector fed to the linear projection.
        channels: size of the normalised (last) dimension.
        eps: epsilon passed to ``layer_norm``.
    """

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # One projection yields both gamma and beta (hence channels * 2).
        self.fc = nn.Linear(style_dim, channels*2)

    def forward(self, x, s):
        """Normalise ``x`` over its channel axis and modulate it with ``s``.

        NOTE(review): for a 3-D ``x`` the two transposes below cancel each
        other, so the input is presumably already channel-last - confirm
        against callers.
        """
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)

        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)

        # Use nn.functional directly: this file never imports ``F`` itself, so
        # the original ``F.layer_norm`` only resolved if a star-import leaked it.
        x = nn.functional.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)
39
+
40
class StyleTransformer1d(nn.Module):
    """Stack of ``StyleTransformerBlock`` layers conditioned on time, features and an embedding.

    The input is concatenated with the embedding along the last axis, a
    time/feature mapping is added before every block, and the result is
    averaged over axis 1 and projected back to ``channels`` by a 1x1 conv.
    ``forward`` additionally implements embedding masking and
    classifier-free-guidance scaling.
    """

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        embedding_max_length: int = 512,
    ):
        super().__init__()

        # The default of None is only a placeholder: the value is used in
        # arithmetic below, so fail early with a readable message.
        assert_message = "context_embedding_features must be provided"
        assert exists(context_embedding_features), assert_message

        self.blocks = nn.ModuleList(
            [
                StyleTransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    style_dim=context_features,
                    use_rel_pos=use_rel_pos,
                    rel_pos_num_buckets=rel_pos_num_buckets,
                    rel_pos_max_distance=rel_pos_max_distance,
                )
                for _ in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        # Learned positional stand-in used when the real embedding is masked.
        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]

        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def run(self, x, time, embedding, features):
        """Apply all blocks to ``x`` joined with ``embedding`` and project out.

        Assumes ``x`` can be expanded along axis 1 to ``embedding.size(1)``
        (e.g. a singleton axis) - TODO confirm against callers.
        """
        mapping = self.get_mapping(time, features)
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            # The mapping is re-injected before every block.
            x = x + mapping
            x = block(x, features)

        # Collapse axis 1 to a single step before the output projection.
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)

        return x

    def forward(self, x: Tensor,
                time: Tensor,
                embedding_mask_proba: float = 0.0,
                embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None,
                embedding_scale: float = 1.0) -> Tensor:
        """Run the transformer with optional embedding masking and guidance.

        Args:
            x: input tensor.
            time: time tensor fed to the time embedding.
            embedding_mask_proba: when > 0, per-batch-item probability of
                swapping ``embedding`` for the fixed embedding.
            embedding: conditioning embedding; required despite the default.
            features: optional feature tensor, also passed to every block.
            embedding_scale: when != 1.0, the result is
                ``out_masked + (out - out_masked) * embedding_scale``.
        """
        assert_message = "embedding must be provided"
        assert exists(embedding), assert_message

        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            # Compute both normal and fixed embedding outputs
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            # Scale conditional output using classifier-free guidance
            return out_masked + (out - out_masked) * embedding_scale

        return self.run(x, time, embedding=embedding, features=features)
186
+
187
+
188
class StyleTransformerBlock(nn.Module):
    """Residual block: style self-attention, optional cross-attention, feed-forward.

    Cross-attention is only built (and used) when ``context_features`` is
    given and greater than zero.
    """

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        style_dim: int,
        multiplier: int,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_cross_attention = exists(context_features) and context_features > 0

        # Arguments shared by the self- and cross-attention layers.
        shared_attention_args = dict(
            features=features,
            style_dim=style_dim,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

        self.attention = StyleAttention(**shared_attention_args)

        if self.use_cross_attention:
            self.cross_attention = StyleAttention(
                context_features=context_features, **shared_attention_args
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Apply each sub-layer with a residual connection around it."""
        hidden = self.attention(x, s) + x
        if self.use_cross_attention:
            hidden = self.cross_attention(hidden, s, context=context) + hidden
        return self.feed_forward(hidden) + hidden
235
+
236
class StyleAttention(nn.Module):
    """Attention layer whose input and context are normalised with ``AdaLayerNorm``.

    Queries come from the input; keys and values come from ``context`` when
    given, otherwise from the input itself.
    """

    def __init__(
        self,
        features: int,
        *,
        style_dim: int,
        head_features: int,
        num_heads: int,
        context_features: Optional[int] = None,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        super().__init__()
        # Keep the raw argument: forward uses it to decide whether a context
        # is mandatory.
        self.context_features = context_features
        inner_features = head_features * num_heads
        kv_features = default(context_features, features)

        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, kv_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=inner_features, bias=False
        )
        # A single projection yields keys and values; it is split in forward.
        self.to_kv = nn.Linear(
            in_features=kv_features, out_features=inner_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Attend from ``x`` to ``context`` (or to ``x`` itself) under style ``s``."""
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message
        # Fall back to self-attention when no context was supplied.
        source = default(context, x)
        normed_x = self.norm(x, s)
        normed_source = self.norm_context(source, s)

        q = self.to_q(normed_x)
        k, v = torch.chunk(self.to_kv(normed_source), chunks=2, dim=-1)
        return self.attention(q, k, v)
282
+
283
class Transformer1d(nn.Module):
    """Stack of ``TransformerBlock`` layers conditioned on time, features and an embedding.

    The input is concatenated with the embedding along the last axis, a
    time/feature mapping is added before every block, and the result is
    averaged over axis 1 and projected back to ``channels`` by a 1x1 conv.
    ``forward`` additionally implements embedding masking and
    classifier-free-guidance scaling.
    """

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        embedding_max_length: int = 512,
    ):
        super().__init__()

        # The default of None is only a placeholder: the value is used in
        # arithmetic below, so fail early with a readable message.
        assert_message = "context_embedding_features must be provided"
        assert exists(context_embedding_features), assert_message

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    use_rel_pos=use_rel_pos,
                    rel_pos_num_buckets=rel_pos_num_buckets,
                    rel_pos_max_distance=rel_pos_max_distance,
                )
                for _ in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        # Learned positional stand-in used when the real embedding is masked.
        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]

        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def run(self, x, time, embedding, features):
        """Apply all blocks to ``x`` joined with ``embedding`` and project out.

        Assumes ``x`` can be expanded along axis 1 to ``embedding.size(1)``
        (e.g. a singleton axis) - TODO confirm against callers.
        """
        mapping = self.get_mapping(time, features)
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            # The mapping is re-injected before every block.
            x = x + mapping
            x = block(x)

        # Collapse axis 1 to a single step before the output projection.
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)

        return x

    def forward(self, x: Tensor,
                time: Tensor,
                embedding_mask_proba: float = 0.0,
                embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None,
                embedding_scale: float = 1.0) -> Tensor:
        """Run the transformer with optional embedding masking and guidance.

        Args:
            x: input tensor.
            time: time tensor fed to the time embedding.
            embedding_mask_proba: when > 0, per-batch-item probability of
                swapping ``embedding`` for the fixed embedding.
            embedding: conditioning embedding; required despite the default.
            features: optional feature tensor used for the mapping.
            embedding_scale: when != 1.0, the result is
                ``out_masked + (out - out_masked) * embedding_scale``.
        """
        assert_message = "embedding must be provided"
        assert exists(embedding), assert_message

        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            # Compute both normal and fixed embedding outputs
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            # Scale conditional output using classifier-free guidance
            return out_masked + (out - out_masked) * embedding_scale

        return self.run(x, time, embedding=embedding, features=features)
428
+
429
+
430
+ """
431
+ Attention Components
432
+ """
433
+
434
+
435
class RelativePositionBias(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position."""

    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.num_heads = num_heads
        # One learned bias value per (bucket, head).
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position: Tensor, num_buckets: int, max_distance: int
    ):
        """Map signed relative positions to bucket indices.

        Half of the buckets serve non-negative offsets and half serve negative
        ones. Within a half, distances below ``max_exact`` keep their own
        bucket, larger ones are spaced logarithmically and capped at the last
        bucket of that half.
        """
        half_buckets = num_buckets // 2
        max_exact = half_buckets // 2
        distance = torch.abs(relative_position)

        # Logarithmic bucket for large distances (not used for small ones).
        log_position = (
            torch.log(distance.float() / max_exact)
            / log(max_distance / max_exact)
            * (half_buckets - max_exact)
        ).long()
        large_bucket = max_exact + log_position
        large_bucket = torch.min(
            large_bucket, torch.full_like(large_bucket, half_buckets - 1)
        )

        # Non-negative offsets are shifted into the upper half of the table.
        bucket = (relative_position >= 0).to(torch.long) * half_buckets
        bucket += torch.where(distance < max_exact, distance, large_bucket)
        return bucket

    def forward(self, num_queries: int, num_keys: int) -> Tensor:
        """Return the bias tensor rearranged to ``1 h m n`` for the given sizes."""
        device = self.relative_attention_bias.weight.device
        # Queries are aligned with the last ``num_queries`` key positions.
        query_positions = torch.arange(
            num_keys - num_queries, num_keys, dtype=torch.long, device=device
        )
        key_positions = torch.arange(num_keys, dtype=torch.long, device=device)
        offsets = rearrange(key_positions, "j -> 1 j") - rearrange(
            query_positions, "i -> i 1"
        )

        buckets = self._relative_position_bucket(
            offsets, num_buckets=self.num_buckets, max_distance=self.max_distance
        )

        bias = self.relative_attention_bias(buckets)
        return rearrange(bias, "m n h -> 1 h m n")
482
+
483
+
484
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Build a Linear -> GELU -> Linear MLP that widens by ``multiplier``.

    The hidden width is ``features * multiplier``; input and output widths are
    both ``features``.
    """
    hidden_features = features * multiplier
    layers = [
        nn.Linear(in_features=features, out_features=hidden_features),
        nn.GELU(),
        nn.Linear(in_features=hidden_features, out_features=features),
    ]
    return nn.Sequential(*layers)
491
+
492
+
493
class AttentionBase(nn.Module):
    """Multi-head scaled dot-product attention over already-projected q, k, v.

    Optionally adds a ``RelativePositionBias`` to the similarity matrix, and
    finishes with a linear projection to ``out_features`` (``features`` when
    not given).
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        use_rel_pos: bool,
        out_features: Optional[int] = None,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        self.use_rel_pos = use_rel_pos
        inner_features = head_features * num_heads

        if use_rel_pos:
            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
            self.rel_pos = RelativePositionBias(
                num_buckets=rel_pos_num_buckets,
                max_distance=rel_pos_max_distance,
                num_heads=num_heads,
            )

        projected_features = features if out_features is None else out_features
        self.to_out = nn.Linear(
            in_features=inner_features, out_features=projected_features
        )

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Return the attention output for the given queries, keys and values."""
        # Split the packed feature axis into separate heads.
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
        similarity = einsum("... n d, ... m d -> ... n m", q, k)
        if self.use_rel_pos:
            # The bias is added before scaling, as in the original code.
            similarity = similarity + self.rel_pos(*similarity.shape[-2:])
        similarity = similarity * self.scale
        weights = similarity.softmax(dim=-1)
        # Weighted sum of values, then merge the heads back together.
        attended = einsum("... n m, ... m d -> ... n d", weights, v)
        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
536
+
537
+
538
class Attention(nn.Module):
    """Pre-norm attention layer: LayerNorm, q/kv projections, then ``AttentionBase``.

    Queries come from the input; keys and values come from ``context`` when
    given, otherwise from the input itself.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        super().__init__()
        # Keep the raw argument: forward uses it to decide whether a context
        # is mandatory.
        self.context_features = context_features
        inner_features = head_features * num_heads
        kv_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(kv_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=inner_features, bias=False
        )
        # A single projection yields keys and values; it is split in forward.
        self.to_kv = nn.Linear(
            in_features=kv_features, out_features=inner_features * 2, bias=False
        )

        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Attend from ``x`` to ``context`` (or to ``x`` itself when absent)."""
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message
        # Fall back to self-attention when no context was supplied.
        source = default(context, x)
        normed_x = self.norm(x)
        normed_source = self.norm_context(source)
        q = self.to_q(normed_x)
        k, v = torch.chunk(self.to_kv(normed_source), chunks=2, dim=-1)
        return self.attention(q, k, v)
585
+
586
+
587
+ """
588
+ Transformer Blocks
589
+ """
590
+
591
+
592
class TransformerBlock(nn.Module):
    """Residual block: self-attention, optional cross-attention, feed-forward.

    Cross-attention is only built (and used) when ``context_features`` is
    given and greater than zero.
    """

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_cross_attention = exists(context_features) and context_features > 0

        # Arguments shared by the self- and cross-attention layers.
        shared_attention_args = dict(
            features=features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

        self.attention = Attention(**shared_attention_args)

        if self.use_cross_attention:
            self.cross_attention = Attention(
                context_features=context_features, **shared_attention_args
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Apply each sub-layer with a residual connection around it."""
        hidden = self.attention(x) + x
        if self.use_cross_attention:
            hidden = self.cross_attention(hidden, context=context) + hidden
        return self.feed_forward(hidden) + hidden
636
+
637
+
638
+
639
+ """
640
+ Time Embeddings
641
+ """
642
+
643
+
644
class SinusoidalEmbedding(nn.Module):
    """Fixed sin/cos embedding of a 1-D tensor using ``dim // 2`` frequencies."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Return ``[sin(x * f), cos(x * f)]`` concatenated on the last axis."""
        half_dim = self.dim // 2
        device = x.device
        # Geometric frequency ladder: exp(-k * log(10000) / (half_dim - 1)).
        step = torch.tensor(log(10000) / (half_dim - 1), device=device)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -step)
        angles = rearrange(x, "i -> i 1") * rearrange(frequencies, "j -> 1 j")
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
655
+
656
+
657
class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time.

    Embeds a 1-D tensor with ``dim // 2`` learned frequencies; the output has
    ``dim + 1`` features because the raw input is prepended.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``[x, sin(2*pi*x*w), cos(2*pi*x*w)]`` joined on the last axis."""
        column = rearrange(x, "b -> b 1")
        phases = column * rearrange(self.weights, "d -> 1 d") * 2 * pi
        waves = torch.cat((phases.sin(), phases.cos()), dim=-1)
        return torch.cat((column, waves), dim=-1)
672
+
673
+
674
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a learned positional embedding followed by a linear projection.

    The linear layer takes ``dim + 1`` inputs because
    ``LearnedPositionalEmbedding`` prepends the raw input to its output.
    """
    embedding = LearnedPositionalEmbedding(dim)
    projection = nn.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(embedding, projection)
679
+
680
class FixedEmbedding(nn.Module):
    """Learned per-position embedding table, broadcast over the batch."""

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        """Return position embeddings shaped like the first two axes of ``x``.

        Only the batch size, length and device of ``x`` are read; its values
        are ignored.
        """
        batch_size, length = x.shape[0:2]
        assert_message = "Input sequence length must be <= max_length"
        assert length <= self.max_length, assert_message
        positions = torch.arange(length, device=x.device)
        table = self.embedding(positions)
        return repeat(table, "n d -> b n d", b=batch_size)
Modules/diffusion/sampler.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import atan, cos, pi, sin, sqrt
2
+ from typing import Any, Callable, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, reduce
8
+ from torch import Tensor
9
+
10
+ from .utils import *
11
+
12
+ """
13
+ Diffusion Training
14
+ """
15
+
16
+ """ Distributions """
17
+
18
+
19
class Distribution:
    """Interface for callables that draw ``num_samples`` values on ``device``."""

    def __call__(self, num_samples: int, device: torch.device):
        # Subclasses must override this with the actual sampling logic.
        raise NotImplementedError()
22
+
23
+
24
class LogNormalDistribution(Distribution):
    """Draws samples whose logarithm is normal with the given mean and std."""

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        """Return ``num_samples`` log-normal samples as a 1-D tensor."""
        gaussian = torch.randn((num_samples,), device=device)
        log_values = self.mean + self.std * gaussian
        return log_values.exp()
34
+
35
+
36
class UniformDistribution(Distribution):
    """Draws samples with ``torch.rand`` (uniform on [0, 1))."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        """Return ``num_samples`` uniform samples as a 1-D tensor."""
        samples = torch.rand(num_samples, device=device)
        return samples
39
+
40
+
41
class VKDistribution(Distribution):
    """Samples sigmas by inverse-CDF sampling between ``min_value`` and ``max_value``.

    The CDF used is ``atan(sigma / sigma_data) * 2 / pi``; a uniform draw in
    ``[cdf(min_value), cdf(max_value)]`` is mapped back through its inverse,
    ``tan(u * pi / 2) * sigma_data``.

    Args:
        min_value: lower bound of the sampled sigmas.
        max_value: upper bound of the sampled sigmas (may be infinite).
        sigma_data: scale parameter of the distribution.
    """

    def __init__(
        self,
        min_value: float = 0.0,
        max_value: float = float("inf"),
        sigma_data: float = 1.0,
    ):
        self.min_value = min_value
        self.max_value = max_value
        self.sigma_data = sigma_data

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        """Return ``num_samples`` sigmas as a 1-D tensor on ``device``."""
        sigma_data = self.sigma_data
        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
        # Inverse-CDF sampling needs a *uniform* draw. The original used
        # torch.randn, which can push u outside [min_cdf, max_cdf] and so
        # yield negative or out-of-range sigmas.
        u = (max_cdf - min_cdf) * torch.rand((num_samples,), device=device) + min_cdf
        return torch.tan(u * pi / 2) * sigma_data
60
+
61
+
62
+ """ Diffusion Classes """
63
+
64
+
65
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Return a view of ``x`` with ``ndim`` extra size-1 axes appended on the right."""
    trailing_ones = (1,) * ndim
    return x.view(*x.shape, *trailing_ones)
68
+
69
+
70
def clip(x: Tensor, dynamic_threshold: float = 0.0):
    """Clamp ``x`` to [-1, 1], statically or with a per-item dynamic threshold.

    With ``dynamic_threshold == 0.0`` the tensor is clamped to [-1, 1]
    directly. Otherwise each batch item is clamped to the
    ``dynamic_threshold`` quantile of its absolute values (at least 1.0) and
    divided by that value.
    """
    if dynamic_threshold != 0.0:
        # One quantile per batch item, computed over all remaining axes.
        flattened = rearrange(x, "b ... -> b (...)")
        scale = torch.quantile(flattened.abs(), dynamic_threshold, dim=-1)
        # Never shrink the range below the static [-1, 1] window.
        scale.clamp_(min=1.0)
        # Append singleton axes so the scale broadcasts against x.
        scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
        return x.clamp(-scale, scale) / scale
    return x.clamp(-1.0, 1.0)
84
+
85
+
86
def to_batch(
    batch_size: int,
    device: torch.device,
    x: Optional[float] = None,
    xs: Optional[Tensor] = None,
) -> Tensor:
    """Return a per-item tensor from either a scalar ``x`` or a tensor ``xs``.

    Exactly one of ``x`` and ``xs`` must be given. A scalar is broadcast to a
    1-D tensor of length ``batch_size`` on ``device``; a tensor is returned
    unchanged.
    """
    has_scalar = exists(x)
    assert has_scalar ^ exists(xs), "Either x or xs must be provided"
    if not has_scalar:
        # The caller already supplied per-item values.
        assert exists(xs)
        return xs
    # Repeat the single value for every batch item.
    filled = torch.full(size=(batch_size,), fill_value=x).to(device)
    assert exists(filled)
    return filled
98
+
99
+
100
class Diffusion(nn.Module):
    """Base diffusion class.

    Subclasses set ``alias`` and override ``denoise_fn`` and ``forward``;
    both raise ``NotImplementedError`` here.
    """

    # Short identifier for the diffusion variant; empty on the base class.
    alias: str = ""

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Denoise ``x_noisy``; must be provided by subclasses."""
        message = "Diffusion class missing denoise_fn"
        raise NotImplementedError(message)

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Run the training forward pass; must be provided by subclasses."""
        message = "Diffusion class missing forward function"
        raise NotImplementedError(message)
117
+
118
+
119
class VDiffusion(Diffusion):
    """Diffusion whose network is trained to predict the v-target."""

    alias = "v"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of ``sigmas * pi / 2``."""
        quarter_turn = sigmas * pi / 2
        return torch.cos(quarter_turn), torch.sin(quarter_turn)

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Run the network on ``x_noisy`` with one noise level per batch item."""
        device = x_noisy.device
        batch_size = x_noisy.shape[0]
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        return self.net(x_noisy, sigmas, **kwargs)

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Return the MSE between the network output and the v-target."""
        device = x.device
        batch_size = x.shape[0]

        # One noise level per batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)

        # Use the supplied noise, or draw fresh Gaussian noise
        noise = default(noise, lambda: torch.randn_like(x))

        # Mix input and noise with cos/sin weights
        alpha, beta = self.get_alpha_beta(rearrange(sigmas, "b -> b 1 1"))
        x_noisy = x * alpha + noise * beta
        x_target = noise * alpha - x * beta

        # Network prediction against the target
        prediction = self.denoise_fn(x_noisy, sigmas, **kwargs)
        return F.mse_loss(prediction, x_target)
163
+
164
+
165
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""

    alias = "k"

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution: Distribution,
        sigma_data: float,  # data distribution standard deviation
        dynamic_threshold: float = 0.0,
    ):
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return ``(c_skip, c_out, c_in, c_noise)`` for the given noise levels.

        ``c_noise`` keeps the shape of ``sigmas``; the other three are computed
        after ``sigmas`` is reshaped from ``(b,)`` to ``(b, 1, 1)``.
        """
        sigma_data = self.sigma_data
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        total_variance = sigmas ** 2 + sigma_data ** 2
        c_in = total_variance ** -0.5
        c_skip = (sigma_data ** 2) / total_variance
        c_out = sigmas * sigma_data * c_in
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return the preconditioned estimate ``c_skip * x_noisy + c_out * net(...)``."""
        device = x_noisy.device
        batch_size = x_noisy.shape[0]
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        # Scale the network input, then blend its output with a skip connection
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        network_out = self.net(c_in * x_noisy, c_noise, **kwargs)
        return c_skip * x_noisy + c_out * network_out

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        """Return the per-item loss weight derived from ``sigmas`` and ``sigma_data``."""
        numerator = sigmas ** 2 + self.sigma_data ** 2
        return numerator * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Return the weighted denoising loss, averaged over the batch."""
        from einops import rearrange, reduce

        device = x.device
        batch_size = x.shape[0]

        # One noise level per batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Corrupt the input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Denoised estimate
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)

        # Per-item mean squared error, weighted by noise level
        per_element = F.mse_loss(x_denoised, x, reduction="none")
        per_item = reduce(per_element, "b ... -> b", "mean")
        weighted = per_item * self.loss_weight(sigmas)
        return weighted.mean()
235
+
236
+
237
class VKDiffusion(Diffusion):
    """Diffusion combining Karras-style input/output scaling with a v-target loss.

    The network receives the scaled noisy input together with the noise level
    mapped through ``sigma_to_t``.
    """

    alias = "vk"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return ``(c_skip, c_out, c_in)`` for the given noise levels.

        ``sigmas`` is reshaped from ``(b,)`` to ``(b, 1, 1)`` first, and the
        data standard deviation is fixed to 1.0. ``c_out`` carries a leading
        minus sign, so it is never positive for non-negative ``sigmas``.
        """
        sigma_data = 1.0
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
        c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
        """Map noise levels to ``atan(sigma) * 2 / pi``."""
        return sigmas.atan() / pi * 2

    def t_to_sigma(self, t: Tensor) -> Tensor:
        """Inverse of ``sigma_to_t``: ``tan(t * pi / 2)``."""
        return (t * pi / 2).tan()

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return ``c_skip * x_noisy + c_out * net(c_in * x_noisy, t)``."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        # Predict network output and add skip connection
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred
        return x_denoised

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Return the MSE between the raw network output and the v-target."""
        batch_size, device = x.shape[0], x.device

        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Add noise to input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Compute model output
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)

        # Compute v-objective target. c_out is <= 0, so the stabilising epsilon
        # is subtracted: adding it would move the denominator *towards* zero
        # (and exactly to zero when c_out == -1e-7) instead of away from it.
        v_target = (x - c_skip * x_noisy) / (c_out - 1e-7)

        # Compute loss
        loss = F.mse_loss(x_pred, v_target)
        return loss
297
+
298
+
299
+ """
300
+ Diffusion Sampling
301
+ """
302
+
303
+ """ Schedules """
304
+
305
+
306
class Schedule(nn.Module):
    """Interface used by different sampling schedules"""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        """Return the noise levels for ``num_steps`` steps (abstract)."""
        raise NotImplementedError
311
+
312
+
313
class LinearSchedule(Schedule):
    """Linearly decreasing noise levels from 1 down towards (but excluding) 0."""

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return ``num_steps`` evenly spaced sigmas starting at 1.

        The values are the first ``num_steps`` entries of
        ``linspace(1, 0, num_steps + 1)``, so the final 0 is dropped.
        """
        # Honour the requested device (it was previously ignored, leaving the
        # schedule on the default device), matching KarrasSchedule.
        sigmas = torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
        return sigmas
317
+
318
+
319
class KarrasSchedule(Schedule):
    """https://arxiv.org/abs/2206.00364 equation 5"""

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return ``num_steps + 1`` sigmas: the schedule followed by a final 0.0.

        The first ``num_steps`` values interpolate between ``sigma_max`` and
        ``sigma_min`` in ``sigma ** (1 / rho)`` space.
        """
        rho_inv = 1.0 / self.rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        # With a single step `num_steps - 1` is 0 and 0 / 0 would give NaN;
        # clamping the divisor makes that case return [sigma_max, 0.0].
        fraction = steps / max(num_steps - 1, 1)
        sigmas = (
            self.sigma_max ** rho_inv
            + fraction * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
        ) ** self.rho
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas
338
+
339
+
340
+ """ Samplers """
341
+
342
+
343
class Sampler(nn.Module):
    """Interface for samplers; both entry points raise until overridden."""

    # Diffusion classes this sampler may be used with
    diffusion_types: List[Type[Diffusion]] = []

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Turn ``noise`` into a sample using denoiser ``fn`` (abstract)."""
        raise NotImplementedError

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Inpaint ``source`` under ``mask``; unsupported unless overridden."""
        raise NotImplementedError("Inpainting not available with current sampler")
362
+
363
+
364
class VSampler(Sampler):
    """Sampler for ``VDiffusion`` that rotates between data and noise estimates."""

    diffusion_types = [VDiffusion]

    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
        """Return ``(cos, sin)`` of ``sigma * pi / 2``."""
        angle = sigma * pi / 2
        alpha = cos(angle)
        beta = sin(angle)
        return alpha, beta

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Run ``num_steps`` denoising steps and return the last data estimate.

        Args:
            noise: Starting noise; scaled by ``sigmas[0]``.
            fn: Denoiser called as ``fn(x, sigma=...)``.
            sigmas: Noise levels; entries ``0 .. num_steps - 1`` are used.
            num_steps: Number of denoiser evaluations.

        Returns:
            The data estimate computed at the final noise level.
        """
        x = sigmas[0] * noise
        alpha, beta = self.get_alpha_beta(sigmas[0].item())

        # Iterate over all `num_steps` levels. The previous bound of
        # `num_steps - 1` never reached the last level (its `is_last` flag
        # could not become true) and left `x_pred` unbound for one step.
        for i in range(num_steps):
            is_last = i == num_steps - 1

            x_denoised = fn(x, sigma=sigmas[i])
            x_pred = x * alpha - x_denoised * beta
            x_eps = x * beta + x_denoised * alpha

            if not is_last:
                # Re-noise the estimate to the next (lower) noise level
                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
                x = x_pred * alpha + x_eps * beta

        return x_pred
392
+
393
+
394
class KarrasSampler(Sampler):
    """https://arxiv.org/abs/2206.00364 algorithm 1"""

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(
        self,
        s_tmin: float = 0,
        s_tmax: float = float("inf"),
        s_churn: float = 0.0,
        s_noise: float = 1.0,
    ):
        super().__init__()
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise
        self.s_churn = s_churn

    def step(
        self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
    ) -> Tensor:
        """Algorithm 2 (step)

        Args:
            x: Current sample at noise level ``sigma``.
            fn: Denoiser called as ``fn(x, sigma=...)``.
            sigma: Current noise level.
            sigma_next: Noise level to step to.
            gamma: Relative amount by which the noise level is first raised.

        Returns:
            The sample at noise level ``sigma_next``.
        """
        # Select temporarily increased noise level
        sigma_hat = sigma + gamma * sigma
        # Add noise to move from sigma to sigma_hat
        epsilon = self.s_noise * torch.randn_like(x)
        x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
        # Evaluate ∂x/∂sigma at sigma_hat
        d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
        # Take euler step from sigma_hat to sigma_next
        x_next = x_hat + (sigma_next - sigma_hat) * d
        # Second order correction
        if sigma_next != 0:
            model_out_next = fn(x_next, sigma=sigma_next)
            d_prime = (x_next - model_out_next) / sigma_next
            # The correction spans sigma_hat -> sigma_next with the averaged
            # slope. Using (sigma - sigma_hat) here made the update vanish
            # whenever gamma == 0, leaving x unchanged by the step.
            x_next = x_hat + 0.5 * (sigma_next - sigma_hat) * (d + d_prime)
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Scale ``noise`` by ``sigmas[0]`` and apply ``num_steps - 1`` steps."""
        x = sigmas[0] * noise
        # Compute gammas
        gammas = torch.where(
            (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
            min(self.s_churn / num_steps, sqrt(2) - 1),
            0.0,
        )
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(
                x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i]  # type: ignore # noqa
            )

        return x
449
+
450
+
451
class AEulerSampler(Sampler):
    """Euler sampler that adds fresh noise after every step."""

    diffusion_types = [KDiffusion, VKDiffusion]

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
        """Return ``(sigma_up, sigma_down)`` for a step from ``sigma`` to ``sigma_next``."""
        up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        down = sqrt(sigma_next ** 2 - up ** 2)
        return up, down

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Take one Euler step to ``sigma_down``, then add ``sigma_up`` of noise."""
        sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
        # Slope ∂x/∂sigma evaluated at the current level
        slope = (x - fn(x, sigma=sigma)) / sigma
        # Deterministic move to sigma_down
        x_down = x + slope * (sigma_down - sigma)
        # Stochastic part
        return x_down + torch.randn_like(x) * sigma_up

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Scale ``noise`` by ``sigmas[0]`` and apply ``num_steps - 1`` steps."""
        x = sigmas[0] * noise
        for index in range(num_steps - 1):
            current, following = sigmas[index], sigmas[index + 1]
            x = self.step(x, fn=fn, sigma=current, sigma_next=following)  # type: ignore # noqa
        return x
479
+
480
+
481
class ADPM2Sampler(Sampler):
    """https://www.desmos.com/calculator/jbxjlqd9mb"""

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Return ``(sigma_up, sigma_down, sigma_mid)`` for one step.

        ``sigma_mid`` is the mean of ``sigma`` and ``sigma_down`` taken in
        ``sigma ** (1 / rho)`` space.
        """
        r = self.rho
        up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        down = sqrt(sigma_next ** 2 - up ** 2)
        mid = ((sigma ** (1 / r) + down ** (1 / r)) / 2) ** r
        return up, down, mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Take one midpoint step to ``sigma_down``, then add ``sigma_up`` of noise."""
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Slope ∂x/∂sigma at the current level
        slope = (x - fn(x, sigma=sigma)) / sigma
        # Provisional move to the midpoint level
        x_mid = x + slope * (sigma_mid - sigma)
        # Slope re-evaluated at the midpoint
        slope_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        # Full move using the midpoint slope
        x_down = x + slope_mid * (sigma_down - sigma)
        # Stochastic part
        return x_down + torch.randn_like(x_down) * sigma_up

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Scale ``noise`` by ``sigmas[0]`` and apply ``num_steps - 1`` steps."""
        x = sigmas[0] * noise
        for index in range(num_steps - 1):
            current, following = sigmas[index], sigmas[index + 1]
            x = self.step(x, fn=fn, sigma=current, sigma_next=following)  # type: ignore # noqa
        return x

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Generate the region where ``mask`` is false, keeping ``source`` elsewhere.

        At each noise level the step is repeated ``num_resamples`` times; all
        but the last repetition are re-noised before the next one.
        """
        x = sigmas[0] * torch.randn_like(source)

        for index in range(num_steps - 1):
            current, following = sigmas[index], sigmas[index + 1]
            # Bring the known content to the current noise level
            source_noisy = source + current * torch.randn_like(source)
            for attempt in range(num_resamples):
                # Known region from the noisy source, the rest from the sample
                x = source_noisy * mask + x * ~mask
                x = self.step(x, fn=fn, sigma=current, sigma_next=following)  # type: ignore # noqa
                # Re-noise unless this was the final repetition
                if attempt < num_resamples - 1:
                    renoise_sigma = sqrt(current ** 2 - following ** 2)
                    x = x + renoise_sigma * torch.randn_like(x)

        return source * mask + x * ~mask
545
+
546
+
547
+ """ Main Classes """
548
+
549
+
550
class DiffusionSampler(nn.Module):
    """Bundles a diffusion's denoiser with a sampler and a sigma schedule."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        sampler: Sampler,
        sigma_schedule: Schedule,
        num_steps: Optional[int] = None,
        clamp: bool = True,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

        # The sampler must list a diffusion type with the same alias
        sampler_class = sampler.__class__.__name__
        diffusion_class = diffusion.__class__.__name__
        message = f"{sampler_class} incompatible with {diffusion_class}"
        supported_aliases = [t.alias for t in sampler.diffusion_types]
        assert diffusion.alias in supported_aliases, message

    def forward(
        self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
    ) -> Tensor:
        """Sample from ``noise``; extra ``kwargs`` are forwarded to the denoiser.

        ``num_steps`` falls back to the value given at construction, and the
        result is clamped to ``[-1, 1]`` when ``clamp`` is set.
        """
        device = noise.device
        num_steps = default(num_steps, self.num_steps)  # type: ignore
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        # Noise levels from the schedule
        sigmas = self.sigma_schedule(num_steps, device)

        # Denoiser with the extra kwargs merged in (used e.g. for conditional unet)
        def fn(*a, **ka):
            return self.denoise_fn(*a, **{**ka, **kwargs})

        x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        if self.clamp:
            x = x.clamp(-1.0, 1.0)
        return x
587
+
588
+
589
class DiffusionInpainter(nn.Module):
    """Bundles a diffusion's denoiser with a sampler's ``inpaint`` routine."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        num_steps: int,
        num_resamples: int,
        sampler: Sampler,
        sigma_schedule: Schedule,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.num_steps = num_steps
        self.num_resamples = num_resamples
        self.inpaint_fn = sampler.inpaint
        self.sigma_schedule = sigma_schedule

    @torch.no_grad()
    def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
        """Inpaint ``inpaint`` under ``inpaint_mask`` with gradients disabled."""
        sigmas = self.sigma_schedule(self.num_steps, inpaint.device)
        return self.inpaint_fn(
            source=inpaint,
            mask=inpaint_mask,
            fn=self.denoise_fn,
            sigmas=sigmas,
            num_steps=self.num_steps,
            num_resamples=self.num_resamples,
        )
617
+
618
+
619
def sequential_mask(like: Tensor, start: int) -> Tensor:
    """Return a boolean mask shaped like ``like`` that is false from ``start`` on.

    Args:
        like: Tensor with at least three dimensions; only its shape and device
            are used.
        start: Index along the third axis from which the mask is false.

    Returns:
        A ``torch.bool`` tensor, true before ``start`` and false afterwards
        along the third axis.
    """
    mask = torch.ones_like(like, dtype=torch.bool)
    # A scalar assignment needs no temporary tensor and, unlike building a
    # zeros tensor of size `length - start`, also works when `start` exceeds
    # the length (the slice is then empty and the mask stays all true).
    mask[:, :, start:] = False
    return mask
624
+
625
+
626
class SpanBySpanComposer(nn.Module):
    """Extends a sequence by repeatedly inpainting a new second half."""

    def __init__(
        self,
        inpainter: DiffusionInpainter,
        *,
        num_spans: int,
    ):
        super().__init__()
        self.inpainter = inpainter
        self.num_spans = num_spans

    def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
        """Generate ``num_spans`` half-length spans continuing ``start``.

        The second half of ``start`` seeds the first half of the inpainting
        buffer; each generated second half is collected and becomes the next
        seed. With ``keep_start`` the two halves of ``start`` come first in
        the concatenated result.
        """
        half_length = start.shape[2] // 2

        if keep_start:
            spans = list(start.chunk(chunks=2, dim=-1))
        else:
            spans = []

        # Buffer whose first half holds the known context
        inpaint = torch.zeros_like(start)
        inpaint[:, :, :half_length] = start[:, :, half_length:]
        inpaint_mask = sequential_mask(like=start, start=half_length)

        for _ in range(self.num_spans):
            generated = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
            second_half = generated[:, :, half_length:]
            # The new second half is the context for the next span
            inpaint[:, :, :half_length] = second_half
            spans.append(second_half)

        return torch.cat(spans, dim=2)
656
+
657
+
658
class XDiffusion(nn.Module):
    """Wrapper that builds the diffusion class whose ``alias`` equals ``type``."""

    def __init__(self, type: str, net: nn.Module, **kwargs):
        super().__init__()

        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
        aliases = [t.alias for t in diffusion_classes]  # type: ignore
        message = f"type='{type}' must be one of {*aliases,}"
        assert type in aliases, message
        self.net = net

        # Instantiate every class whose alias matches (aliases are distinct)
        for diffusion_class in diffusion_classes:
            if diffusion_class.alias == type:  # type: ignore
                self.diffusion = diffusion_class(net=net, **kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        """Delegate to the wrapped diffusion."""
        return self.diffusion(*args, **kwargs)

    def sample(
        self,
        noise: Tensor,
        num_steps: int,
        sigma_schedule: Schedule,
        sampler: Sampler,
        clamp: bool,
        **kwargs,
    ) -> Tensor:
        """Build a ``DiffusionSampler`` for this call and sample from ``noise``."""
        sampling_options = dict(
            sampler=sampler,
            sigma_schedule=sigma_schedule,
            num_steps=num_steps,
            clamp=clamp,
        )
        diffusion_sampler = DiffusionSampler(diffusion=self.diffusion, **sampling_options)
        return diffusion_sampler(noise, **kwargs)
Modules/diffusion/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
def iff(condition: bool, value: T) -> Optional[T]:
    """Return ``value`` when ``condition`` is truthy, otherwise ``None``."""
    if not condition:
        return None
    return value
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return ``val`` unless it is ``None``, else the fallback ``d``.

    A fallback that is a plain Python function (as judged by
    ``inspect.isfunction``) is called and its result returned; any other
    fallback is returned as is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
31
+
32
+
33
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Return ``val`` as a list.

    A list is returned unchanged (same object), a tuple is converted, and
    anything else is wrapped in a single-element list.
    """
    if isinstance(val, list):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val]  # type: ignore
39
+
40
+
41
def prod(vals: Sequence[int]) -> int:
    """Return the product of ``vals``; the product of no values is 1."""
    # The initial value 1 makes an empty sequence return the multiplicative
    # identity instead of raising TypeError.
    return reduce(lambda x, y: x * y, vals, 1)
43
+
44
+
45
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to ``x``.

    The two candidates are 2 raised to the floor and to the ceiling of
    ``log2(x)``; on a tie the smaller one is returned.
    """
    exponent = log2(x)
    lower, upper = floor(exponent), ceil(exponent)
    if abs(x - 2 ** upper) < abs(x - 2 ** lower):
        return 2 ** int(upper)
    return 2 ** int(lower)
50
+
51
def rand_bool(shape, proba, device = None):
    """Return a boolean tensor of ``shape`` whose entries are true with probability ``proba``.

    ``proba`` equal to 1 or 0 yields an all-true or all-false tensor without
    drawing random numbers; otherwise each entry is a Bernoulli draw.
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    probabilities = torch.full(shape, proba, device=device)
    return torch.bernoulli(probabilities).to(torch.bool)
58
+
59
+
60
+ """
61
+ Kwargs Utils
62
+ """
63
+
64
+
65
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split ``d`` into ``(entries whose key starts with prefix, the rest)``.

    Keys are left untouched and keep their original relative order.
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key, value in d.items():
        bucket = with_prefix if key.startswith(prefix) else without_prefix
        bucket[key] = value
    return with_prefix, without_prefix
71
+
72
+
73
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split ``d`` by key prefix, stripping the prefix unless ``keep_prefix``.

    Returns ``(matching entries, remaining entries)``; the matching keys have
    ``prefix`` removed from their start when ``keep_prefix`` is false.
    """
    matching, remaining = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        cut = len(prefix)
        matching = {key[cut:]: value for key, value in matching.items()}
    return matching, remaining
79
+
80
+
81
def prefix_dict(prefix: str, d: Dict) -> Dict:
    """Return a copy of ``d`` with ``prefix`` prepended to the string form of each key."""
    renamed: Dict = {}
    for key, value in d.items():
        renamed[prefix + str(key)] = value
    return renamed