coralLight committed on
Commit
012e1d0
·
1 Parent(s): 680053b

add inference

Browse files
Files changed (8) hide show
  1. NoiseTransformer.py +26 -0
  2. SVDNoiseUnet.py +430 -0
  3. app.py +330 -0
  4. dpm_solver_v3.py +904 -0
  5. free_lunch_utils.py +303 -0
  6. requirements.txt +11 -0
  7. sampler.py +315 -0
  8. uni_pc.py +757 -0
NoiseTransformer.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from torch.nn import functional as F
4
+ from timm import create_model
5
+
6
+
7
+ __all__ = ['NoiseTransformer']
8
+
9
class NoiseTransformer(nn.Module):
    """Swin-Transformer-based noise refiner.

    Maps a 4-channel latent to 3 channels, upsamples it to the fixed
    224x224 input of swin_tiny_patch4_window7_224, then projects the
    backbone features back down to the latent resolution and channels.
    """

    def __init__(self, resolution=(128, 96)):
        super().__init__()
        # 224x224 is the fixed input size of the pretrained Swin backbone.
        self.upsample = lambda x: F.interpolate(x, [224, 224])
        self.downsample = lambda x: F.interpolate(x, [resolution[0], resolution[1]])
        self.upconv = nn.Conv2d(7, 4, (1, 1), (1, 1), (0, 0))
        self.downconv = nn.Conv2d(4, 3, (1, 1), (1, 1), (0, 0))
        self.swin = create_model("swin_tiny_patch4_window7_224", pretrained=True)

    def forward(self, x, residual=False):
        """Refine noise `x`; optionally add a residual skip connection.

        Args:
            x: latent tensor — presumably (B, 4, H, W) with H, W matching
               `resolution`; TODO confirm with callers.
            residual: when True, add the input back onto the output.
        """
        # FIX: the original duplicated this whole expression in both
        # branches of the if/else; compute the shared path once.
        out = self.upconv(
            self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))
        )
        return out + x if residual else out
SVDNoiseUnet.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import einops
4
+
5
+ from torch.nn import functional as F
6
+ from torch.jit import Final
7
+ from timm.layers import use_fused_attn
8
+ from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, get_act_layer
9
+ from abc import abstractmethod
10
+ from NoiseTransformer import NoiseTransformer
11
+ from einops import rearrange
12
+ __all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
13
+
14
class Attention(nn.Module):
    """Multi-head self-attention, using the fused SDPA kernel when available."""

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Single projection producing queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, tokens, channels = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if self.fused_attn:
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v

        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
63
+
64
+
65
class SVDNoiseUnet(nn.Module):
    """SVD-domain noise refiner for (128, 96) latents.

    The four latent channels are tiled into one large matrix, decomposed with
    torch.linalg.svd, and new singular values are predicted from U, s and V
    before reconstructing the matrix and unfolding it back to channels.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=(128,96)): # resolution = size // 8
        super(SVDNoiseUnet, self).__init__()

        _in_1 = int(resolution[0] * in_channels // 2)
        _out_1 = int(resolution[0] * out_channels // 2)

        _in_2 = int(resolution[1] * in_channels // 2)
        _out_2 = int(resolution[1] * out_channels // 2)
        # Branch MLPs over rows of U^T, rows of V, and the singular values.
        self.mlp1 = nn.Sequential(
            nn.Linear(_in_1, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out_1),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(_in_2, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out_2),
        )

        self.mlp3 = nn.Sequential(
            nn.Linear(_in_2, _out_2),
        )

        self.attention = Attention(_out_2)

        # NOTE(review): the normalization/FFN sizes below are hard-coded for
        # resolution=(128, 96) (256/192 after channel folding) and ignore the
        # constructor arguments — confirm before using other resolutions.
        self.bn = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(192)

        self.mlp4 = nn.Sequential(
            nn.Linear(_out_2, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, _out_2),
        )
        self.ffn = nn.Sequential(
            nn.Linear(256, 384), # Expand
            nn.ReLU(inplace=True),
            nn.Linear(384, 192) # Reduce to target size
        )
        self.ffn2 = nn.Sequential(
            nn.Linear(256, 384), # Expand
            nn.ReLU(inplace=True),
            nn.Linear(384, 192) # Reduce to target size
        )

    def forward(self, x, residual=False):
        # `residual` is accepted but unused in this variant.
        b, c, h, w = x.shape
        # Fold the 4 latent channels into one matrix: (b, 4, h, w) -> (b, 2h, 2w).
        x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
        U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
        U_T = U.permute(0, 2, 1)
        # U branch: MLP + FFN applied along both axes with batch norms between.
        U_out = self.ffn(self.mlp1(U_T))
        U_out = self.bn(U_out)
        U_out = U_out.transpose(1, 2)
        U_out = self.ffn2(U_out) # [b, 256, 256] -> [b, 256, 192]
        U_out = self.bn2(U_out)
        U_out = U_out.transpose(1, 2)
        V_out = self.mlp2(V)
        s_out = self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
        # Merge the three branches and pool with self-attention.
        out = U_out + V_out + s_out
        out = out.squeeze(1)
        out = self.attention(out).mean(1)
        # Predict the new singular values as a residual on the originals.
        out = self.mlp4(out) + s
        diagonal_out = torch.diag_embed(out)
        # Zero-pad the diagonal so U @ diag @ V is shape-compatible.
        padded_diag = F.pad(diagonal_out, (0, 0, 0, 64), mode='constant', value=0) # Shape: [b, 1, 256, 192]
        pred = U @ padded_diag @ V
        # Unfold the matrix back to (b, 4, h, w).
        return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
134
+
135
class SVDNoiseUnet64(nn.Module):
    """SVD-domain noise refiner for 64x64 latents.

    Tiles the four latent channels into one square matrix, predicts new
    singular values from its SVD factors, and rebuilds the matrix.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=64):  # resolution = size // 8
        super(SVDNoiseUnet64, self).__init__()

        _in = int(resolution * in_channels // 2)
        _out = int(resolution * out_channels // 2)

        # Branch MLPs over rows of U^T, rows of V, and the singular values.
        self.mlp1 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )
        self.mlp3 = nn.Sequential(
            nn.Linear(_in, _out),
        )

        self.attention = Attention(_out)

        # NOTE(review): registered but never used in forward().
        self.bn = nn.BatchNorm2d(_out)

        # Final head predicting the singular-value residual.
        self.mlp4 = nn.Sequential(
            nn.Linear(_out, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, _out),
        )

    def forward(self, x, residual=False):
        # `residual` is accepted but unused in this variant.
        b, c, h, w = x.shape
        # Fold the 4 channels into one matrix: (b, 4, h, w) -> (b, 2h, 2w).
        mat = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2)
        U, s, V = torch.linalg.svd(mat)
        # Mix the three SVD factors, pool with attention, predict new s.
        mixed = self.mlp1(U.permute(0, 2, 1)) + self.mlp2(V) + self.mlp3(s).unsqueeze(1)
        pooled = self.attention(mixed).mean(1)
        new_s = self.mlp4(pooled) + s
        recon = U @ torch.diag_embed(new_s) @ V
        # Unfold back to (b, 4, h, w).
        return einops.rearrange(recon, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
176
+
177
+
178
+
179
class SVDNoiseUnet128(nn.Module):
    """SVD-domain noise refiner for 128x128 latents.

    NOTE(review): byte-for-byte identical to SVDNoiseUnet64 except for the
    default `resolution`; consider unifying the two classes.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=128): # resolution = size // 8
        super(SVDNoiseUnet128, self).__init__()

        _in = int(resolution * in_channels // 2)
        _out = int(resolution * out_channels // 2)
        # Branch MLPs over rows of U^T, rows of V, and the singular values.
        self.mlp1 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )

        self.mlp3 = nn.Sequential(
            nn.Linear(_in, _out),
        )

        self.attention = Attention(_out)

        # NOTE(review): registered but never used in forward().
        self.bn = nn.BatchNorm2d(_out)

        self.mlp4 = nn.Sequential(
            nn.Linear(_out, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, _out),
        )

    def forward(self, x, residual=False):
        # `residual` is accepted but unused in this variant.
        b, c, h, w = x.shape
        # Fold the 4 channels into one matrix: (b, 4, h, w) -> (b, 2h, 2w).
        x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
        U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
        U_T = U.permute(0, 2, 1)
        out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
        out = self.attention(out).mean(1)
        # Predict the new singular values as a residual on the originals.
        out = self.mlp4(out) + s
        pred = U @ torch.diag_embed(out) @ V
        return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
220
+
221
+
222
+
223
class SVDNoiseUnet_Concise(nn.Module):
    """Placeholder: exported in __all__ but not yet implemented.

    NOTE(review): instantiating this yields a module with no parameters and
    no forward() beyond nn.Module's default — confirm whether this stub is
    still needed.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=64):
        super(SVDNoiseUnet_Concise, self).__init__()
226
+
227
+
228
+ from diffusers.models.normalization import AdaGroupNorm
229
+
230
class NPNet(nn.Module):
    """Noise-prompt network for (128, 96) latents.

    Combines an SVD-domain unet (SVDNoiseUnet), a Swin-based embedding unet
    (NoiseTransformer), and a prompt-conditioned AdaGroupNorm, blended with
    learned scalars alpha and beta.
    """

    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
        super(NPNet, self).__init__()

        assert model_id in ['SD1.5', 'DreamShaper', 'DiT']

        self.model_id = model_id
        self.device = device
        # A path containing '.pth' triggers checkpoint loading in get_model().
        self.pretrained_path = pretrained_path

        (
            self.unet_svd,
            self.unet_embedding,
            self.text_embedding,
            self._alpha,
            self._beta,
        ) = self.get_model()

    def save_model(self, save_path: str):
        """
        Save this NPNet so that get_model() can later reload it.
        """
        torch.save({
            "unet_svd": self.unet_svd.state_dict(),
            "unet_embedding": self.unet_embedding.state_dict(),
            # Key intentionally spelled "embeeding" to stay compatible with
            # existing checkpoints read by get_model().
            "embeeding": self.text_embedding.state_dict(),
            "alpha": self._alpha,
            "beta": self._beta,
        }, save_path)
        print(f"NPNet saved to {save_path}")

    def get_model(self):
        """Build the three sub-networks; load weights when a .pth path is set.

        Returns:
            (unet_svd, unet_embedding, text_embedding, _alpha, _beta)
        """
        unet_embedding = NoiseTransformer(resolution=(128, 96)).to(self.device).to(torch.float32)
        unet_svd = SVDNoiseUnet(resolution=(128, 96)).to(self.device).to(torch.float32)

        # FIX: the original had an if/else on model_id whose two branches
        # constructed the exact same AdaGroupNorm; collapsed to one call.
        text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)

        # Random initialization used when no checkpoint is provided.
        _alpha = torch.randn(1, device=self.device)
        _beta = torch.randn(1, device=self.device)

        if '.pth' in self.pretrained_path:
            checkpoint = torch.load(self.pretrained_path)
            unet_svd.load_state_dict(checkpoint["unet_svd"], strict=True)
            unet_embedding.load_state_dict(checkpoint["unet_embedding"], strict=True)
            text_embedding.load_state_dict(checkpoint["embeeding"], strict=True)
            _alpha = checkpoint["alpha"]
            _beta = checkpoint["beta"]
            print("Load Successfully!")

        # FIX: the original duplicated this return in both arms of the
        # if/else; a single return is equivalent.
        return unet_svd, unet_embedding, text_embedding, _alpha, _beta

    def forward(self, initial_noise, prompt_embeds):
        """Map `initial_noise` to "golden noise" conditioned on the prompt.

        Args:
            initial_noise: latent noise tensor fed to both sub-unets.
            prompt_embeds: prompt embeddings, flattened to (B, -1) before
                conditioning AdaGroupNorm.
        """
        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)

        encoder_hidden_states_svd = initial_noise
        encoder_hidden_states_embedding = initial_noise + text_emb

        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())

        # SVD branch + sigmoid-gated text embedding + beta-scaled Swin branch.
        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
            2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding

        return golden_noise
303
+
304
+
305
class NPNet64(nn.Module):
    """Noise-prompt network for 64x64 latents.

    Combines an SVD-domain unet (SVDNoiseUnet64), a Swin-based embedding
    unet (NoiseTransformer), and a prompt-conditioned AdaGroupNorm.
    """

    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
        super(NPNet64, self).__init__()
        self.model_id = model_id
        self.device = device
        # A path containing '.pth' triggers checkpoint loading in get_model().
        self.pretrained_path = pretrained_path

        (
            self.unet_svd,
            self.unet_embedding,
            self.text_embedding,
            self._alpha,
            self._beta
        ) = self.get_model()

    def save_model(self, save_path: str):
        """
        Save this NPNet so that get_model() can later reload it.
        """
        torch.save({
            "unet_svd": self.unet_svd.state_dict(),
            "unet_embedding": self.unet_embedding.state_dict(),
            "embeeding": self.text_embedding.state_dict(), # matches get_model’s key
            "alpha": self._alpha,
            "beta": self._beta,
        }, save_path)
        print(f"NPNet saved to {save_path}")

    def get_model(self):
        """Build the three sub-networks; load weights when a .pth path is set.

        Returns:
            (unet_svd, unet_embedding, text_embedding, _alpha, _beta)
        """
        unet_embedding = NoiseTransformer(resolution=(64,64)).to(self.device).to(torch.float32)
        unet_svd = SVDNoiseUnet64(resolution=64).to(self.device).to(torch.float32)
        # Random initialization used when no checkpoint is provided.
        _alpha = torch.randn(1, device=self.device)
        _beta = torch.randn(1, device=self.device)

        text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)

        if '.pth' in self.pretrained_path:
            gloden_unet = torch.load(self.pretrained_path)
            unet_svd.load_state_dict(gloden_unet["unet_svd"])
            unet_embedding.load_state_dict(gloden_unet["unet_embedding"])
            text_embedding.load_state_dict(gloden_unet["embeeding"])
            _alpha = gloden_unet["alpha"]
            _beta = gloden_unet["beta"]

            print("Load Successfully!")

        return unet_svd, unet_embedding, text_embedding, _alpha, _beta

    def forward(self, initial_noise, prompt_embeds):
        """Map `initial_noise` to "golden noise" conditioned on the prompt."""
        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)

        encoder_hidden_states_svd = initial_noise
        encoder_hidden_states_embedding = initial_noise + text_emb

        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())

        # SVD branch + sigmoid-gated text embedding + beta-scaled Swin branch.
        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
            2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding

        return golden_noise
370
+
371
class NPNet128(nn.Module):
    """Noise-prompt network for 128x128 latents (requires a checkpoint)."""

    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
        # FIX: the default was `pretrained_path=True`, which made get_model()
        # crash with `TypeError: argument of type 'bool' is not iterable` on
        # `'.pth' in self.pretrained_path`. Use the same sentinel string as
        # the other NPNet variants.
        super(NPNet128, self).__init__()

        assert model_id in ['SDXL', 'DreamShaper', 'DiT']

        self.model_id = model_id
        self.device = device
        self.pretrained_path = pretrained_path

        (
            self.unet_svd,
            self.unet_embedding,
            self.text_embedding,
            self._alpha,
            self._beta,
        ) = self.get_model()

    def get_model(self):
        """Build the three sub-networks and load mandatory pretrained weights.

        Raises:
            FileNotFoundError: when no .pth checkpoint path is supplied.
        """
        unet_embedding = NoiseTransformer(resolution=(128, 128)).to(self.device).to(torch.float32)
        unet_svd = SVDNoiseUnet128(resolution=128).to(self.device).to(torch.float32)

        # DiT uses a smaller text-encoder width than the SDXL-family models.
        if self.model_id == 'DiT':
            text_embedding = AdaGroupNorm(1024 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
        else:
            text_embedding = AdaGroupNorm(2048 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)

        if '.pth' in self.pretrained_path:
            checkpoint = torch.load(self.pretrained_path)
            unet_svd.load_state_dict(checkpoint["unet_svd"])
            unet_embedding.load_state_dict(checkpoint["unet_embedding"])
            # Historical key spelling kept for checkpoint compatibility.
            text_embedding.load_state_dict(checkpoint["embeeding"])
            _alpha = checkpoint["alpha"]
            _beta = checkpoint["beta"]

            print("Load Successfully!")

            return unet_svd, unet_embedding, text_embedding, _alpha, _beta

        # FIX: the original `assert ("No Pretrained Weights Found!")` was a
        # no-op (a non-empty string is truthy) and then fell through to an
        # implicit `return None`, crashing the caller's tuple unpack. Fail
        # loudly and clearly instead.
        raise FileNotFoundError(
            "NPNet128 requires a pretrained .pth checkpoint; "
            f"got pretrained_path={self.pretrained_path!r}"
        )

    def forward(self, initial_noise, prompt_embeds):
        """Map `initial_noise` to "golden noise" conditioned on the prompt."""
        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)

        encoder_hidden_states_svd = initial_noise
        encoder_hidden_states_embedding = initial_noise + text_emb

        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())

        # SVD branch + sigmoid-gated text embedding + beta-scaled Swin branch.
        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
            2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding

        return golden_noise
430
+
app.py CHANGED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import json
5
+ import spaces #[uncomment to use ZeroGPU]
6
+ from diffusers import (
7
+ AutoencoderKL,
8
+ StableDiffusionXLPipeline,
9
+ )
10
+ from huggingface_hub import login, hf_hub_download
11
+ from PIL import Image
12
+ # from huggingface_hub import login
13
+ from SVDNoiseUnet import NPNet64
14
+ import functools
15
+ import random
16
+ from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
17
+ import torch
18
+ import torch.nn as nn
19
+ from einops import rearrange
20
+ from torchvision.utils import make_grid
21
+ import time
22
+ from pytorch_lightning import seed_everything
23
+ from torch import autocast
24
+ from contextlib import contextmanager, nullcontext
25
+ import accelerate
26
+ import torchsde
27
+ from SVDNoiseUnet import NPNet128
28
+ from tqdm import tqdm, trange
29
+ from itertools import islice
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ model_repo_id = "Lykon/dreamshaper-xl-1-0" # Replace to the model you would like to use
32
+ from sampler import UniPCSampler
33
+
34
+ precision_scope = autocast
35
+
36
def chunk(it, size):
    """Split iterable `it` into `size`-length tuples; the final one may be shorter."""
    source = iter(it)

    def next_piece():
        return tuple(islice(source, size))

    # iter(callable, sentinel) stops when an empty tuple is produced.
    return iter(next_piece, ())
39
+
40
+
41
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    # Promote a single HWC image to a batch of one.
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in images]
51
+
52
+
53
def load_replacement(x):
    """Best-effort: swap image `x` for the placeholder asset.

    Returns the placeholder resized to `x`'s spatial shape, or `x` unchanged
    on any failure (missing asset, shape mismatch, ...).
    """
    try:
        height, width = x.shape[0], x.shape[1]
        replacement = Image.open("assets/rick.jpeg").convert("RGB").resize((width, height))
        replacement = (np.array(replacement) / 255.0).astype(x.dtype)
        assert replacement.shape == x.shape
        return replacement
    except Exception:
        # Deliberate broad catch: fall back to the original image.
        return x
62
+
63
+
64
+ # Adapted from pipelines.StableDiffusionPipeline.encode_prompt
65
# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
    """Tokenize and encode a batch of captions with `text_encoder`.

    With probability `proportion_empty_prompts` a caption is replaced by the
    empty string (classifier-free-guidance style dropout). List/array entries
    contribute a random caption during training, the first one otherwise.
    """
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            picked = random.choice(caption) if is_train else caption[0]
            captions.append(picked)

    with torch.no_grad():
        tokenized = tokenizer(
            captions,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        ids = tokenized.input_ids.to(text_encoder.device)
        return text_encoder(ids)[0]
88
+
89
# NOTE(review): exact duplicate of the `chunk` defined earlier in this
# module; this re-definition shadows the first. Consider removing one copy.
def chunk(it, size):
    """Split iterable `it` into `size`-length tuples; the final one may be shorter."""
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
92
+
93
def convert_caption_json_to_str(json):
    """Return the caption string from a caption record (a dict with key "caption").

    NOTE(review): the parameter name shadows the stdlib `json` module.
    """
    return json["caption"]
96
+
97
def prepare_sdxl_pipeline_step_parameter(pipe, prompts, need_cfg, device, negative_prompts, W = 1024, H = 1024):
    """Encode prompts and build SDXL added-conditioning kwargs for one step.

    Args:
        pipe: an SDXL pipeline providing encode_prompt, unet, text_encoder_2.
        prompts / negative_prompts: caption inputs for pipe.encode_prompt.
        need_cfg: when True, prepend the unconditional embeddings so the
            batch is [uncond, cond] (classifier-free guidance layout).
        device: target device for the returned tensors.
        W, H: target image size used for SDXL micro-conditioning time ids.

    Returns:
        (prompt_embeds, {"text_embeds": ..., "time_ids": ...})

    Raises:
        ValueError: when the unet's add_embedding width does not match the
            constructed time-id/pooled-embedding dimensionality.
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompts,
        negative_prompt=negative_prompts,
        device=device,
        do_classifier_free_guidance=need_cfg,
    )

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = pooled_prompt_embeds.to(device)
    # Micro-conditioning: original size, crop offset, target size.
    original_size = (W, H)
    crops_coords_top_left = (0, 0)
    target_size = (W, H)
    text_encoder_projection_dim = None
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    if pipe.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
    # Sanity check that the unet's add_embedding layer expects exactly the
    # dimensionality we are about to feed it.
    passed_add_embed_dim = (
        pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    )
    expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
    add_time_ids = add_time_ids.to(device)
    negative_add_time_ids = add_time_ids

    if need_cfg:
        # Unconditional first, conditional second (diffusers CFG convention).
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
    ret_dict = {
        "text_embeds": add_text_embeds,
        "time_ids": add_time_ids
    }
    return prompt_embeds, ret_dict
143
+
144
+
145
def model_closure(pipe):
    """Wrap `pipe.unet` into a `model_fn(x, t, c)` callable for the sampler.

    `c` is a sequence whose first element is the prompt embeddings; an
    optional second element carries the SDXL added-conditioning kwargs.
    """
    def model_fn(x, t, c):
        prompt_embeds = c[0]
        cond_kwargs = c[1] if len(c) > 1 else None
        return pipe.unet(
            x,
            t,
            encoder_hidden_states=prompt_embeds.to(device=x.device, dtype=x.dtype),
            added_cond_kwargs=cond_kwargs,
        ).sample

    return model_fn
157
+
158
+
159
# ---------------------------------------------------------------------------
# Module-level model setup (runs once at import time).
# ---------------------------------------------------------------------------
torch_dtype = torch.float16
repo_id = "madebyollin/sdxl-vae-fp16-fix" # e.g., "distilbert/distilgpt2"
# fp16-safe SDXL VAE replacement for stable half-precision decoding.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix",torch_dtype=torch_dtype) #from_single_file(downloaded_path, torch_dtype=torch_dtype)
vae.to('cuda')

pipe = StableDiffusionXLPipeline.from_pretrained("John6666/illustrij-evo-lvl3-sdxl",torch_dtype=torch_dtype,vae=vae)
# pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",torch_dtype=torch.float16,vae=vae)

pipe.to('cuda')


# Upper bound for the seed slider in the UI.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

accelerator = accelerate.Accelerator()
175
+
176
def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps):
    """Sample one image with the UniPC sampler and decode it to a PIL image.

    Args:
        prompt / negative_prompt: caption strings.
        seed: RNG seed applied before sampling.
        width, height: target pixel resolution.
        guidance_scale: CFG scale; 1.0 disables the unconditional branch.
        num_inference_steps: sampler step count (UI offers 6 or 8).

    Returns:
        The decoded image as a PIL.Image (batch size is fixed at 1).
    """
    # FIX: `seed` was previously accepted but never used, so the UI seed had
    # no effect on generation; seed all RNGs for reproducibility.
    seed_everything(seed)

    prompts = [prompt]
    sampler = UniPCSampler(pipe, model_closure=model_closure, steps=num_inference_steps, guidance_scale=guidance_scale)
    c = prompts
    uc = [negative_prompt] * len(c) if guidance_scale != 1.0 else None
    # Latent shape is 1/8 of the pixel resolution.
    # NOTE(review): passed as (C, width//8, height//8) — confirm the sampler
    # expects (C, W, H) rather than the usual (C, H, W).
    shape = [4, width // 8, height // 8]
    samples, _ = sampler.sample(
        conditioning=c,
        batch_size=1,
        shape=shape,
        unconditional_conditioning=uc,
        x_T=None,
        # FreeU kicks in later when more steps are available.
        start_free_u_step=6 if num_inference_steps == 8 else 4,
        xl_preprocess_closure=prepare_sdxl_pipeline_step_parameter,
        use_corrector=True,
    )

    # Decode latents and map pixel values from [-1, 1] to [0, 1].
    x_samples = pipe.vae.decode(samples / pipe.vae.config.scaling_factor).sample
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples = x_samples.cpu().permute(0, 2, 3, 1).numpy()

    x_image_torch = torch.from_numpy(x_samples).permute(0, 3, 1, 2)

    img = None  # batch_size is 1, so the loop produces exactly one image
    for x_sample in x_image_torch:
        arr = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
        img = Image.fromarray(arr.astype(np.uint8))
    return img
207
+
208
@spaces.GPU  # [uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    resolution,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler: render a fast (6/8-step) image next to a 50-step one.

    Returns (quick_image, reference_image, seed) for the UI outputs.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Resolution arrives as a "WIDTHxHEIGHT" string from the dropdown.
    width, height = (int(part) for part in resolution.split('x'))

    quick = generate_image_with_steps(
        prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps
    )
    reference = generate_image_with_steps(
        prompt, negative_prompt, seed, width, height, guidance_scale, 50
    )

    return quick, reference, seed
232
+
233
+
234
# Example prompts shown beneath the input box.
examples = [
    "Astronaut in a jungle, cold color, muted colors, detailed, 8k",
    "a painting of a virus monster playing guitar",
    "a painting of a squirrel eating a burger",
]

# Center the app column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Hyperparameters are all you need")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        # Side-by-side comparison: fast sampler vs. 50-step reference.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Our fast inference Result")
                result = gr.Image(label="Quick Result", show_label=False)
            with gr.Column():
                gr.Markdown("### Original 50 steps Result")
                result_50_steps = gr.Image(label="50 Steps Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            resolution = gr.Dropdown(
                choices=[
                    "1024x1024",
                    "1216x832",
                    "832x1216"
                ],
                value="1024x1024",
                label="Resolution",
            )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5, # Replace with defaults that work for your model
                )

                num_inference_steps = gr.Dropdown(
                    choices=[6, 8],
                    value=8,
                    label="Number of inference steps",
                )

        gr.Examples(examples=examples, inputs=[prompt])
    # Run on button click or when the prompt box is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            resolution,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, result_50_steps, seed],
    )

if __name__ == "__main__":
    demo.launch()
dpm_solver_v3.py ADDED
@@ -0,0 +1,904 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import numpy as np
5
+ import os
6
+
7
+
8
class NoiseScheduleVP:
    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):
        """Create a wrapper class for the forward SDE (VP type).

        The forward SDE ensures q_{t|0}(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 * I).
        We define lambda_t = log(alpha_t) - log(sigma_t) (the half-logSNR from the
        DPM-Solver paper) and expose:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t     = self.marginal_std(t)
            lambda_t    = self.marginal_lambda(t)
            t           = self.inverse_lambda(lambda_t)   # inverse of lambda(t)

        Supported schedules:

        1. ``schedule='discrete'`` (discrete-time DPMs trained on n = 0..N-1):
           discrete steps are mapped to continuous time t_i = (i + 1) / N, and
           log_alpha_t is obtained by piecewise-linear interpolation.
           Provide exactly one of:
               betas: the DDPM beta array, or
               alphas_cumprod: the DDPM \\hat{alpha}_n array; note
               alpha_{t_n} = sqrt(\\hat{alpha}_n), so
               log(alpha_{t_n}) = 0.5 * log(\\hat{alpha}_n).

        2. ``schedule='linear'`` / ``schedule='cosine'`` (continuous-time VPSDEs),
           with the default DDPM / improved-DDPM hyperparameters
           (``continuous_beta_0``, ``continuous_beta_1``).

        Example:
            >>> ns = NoiseScheduleVP('discrete', betas=betas)
            >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
            >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """
        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )
        self.alphas_cumprod = alphas_cumprod
        # Bug fix: the original computed `sigmas` unconditionally, which raised a
        # TypeError whenever only `betas` was supplied (alphas_cumprod is None).
        # These Karras-style sigmas are only derivable (and only used, via
        # sigma_to_t / get_special_sigmas_with_timesteps) when alphas_cumprod is given.
        if alphas_cumprod is not None:
            self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
            self.log_sigmas = self.sigmas.log()
        self.schedule = schedule
        if schedule == "discrete":
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            # Interpolation tables for marginal_log_mean_coeff / inverse_lambda.
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            self.schedule = schedule
            if schedule == "cosine":
                # For the cosine schedule, T = 1 has numerical issues, so the
                # ending time is clipped (0.9946 is empirical, not optimal).
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def sigma_to_t(self, sigma, quantize=None):
        """Map a Karras-style sigma to a (fractional) discrete timestep by
        piecewise-linear interpolation in log-sigma space.

        NOTE(review): `quantize` is forced to None here, so the nearest-index
        branch below is dead code; kept for interface compatibility.
        Requires `alphas_cumprod` to have been provided at construction.
        """
        quantize = None
        log_sigma = sigma.log()
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def get_special_sigmas_with_timesteps(self, timesteps):
        """Interpolate alphas_cumprod at fractional timesteps (numpy array) and
        return the corresponding Karras sigmas sqrt((1 - alpha) / alpha).

        Bug fix: `np.floor`/`np.ceil` return float arrays, and torch tensors
        cannot be indexed by float arrays — cast the indices to int64.
        """
        low_idx = np.minimum(np.floor(timesteps), 999).astype(np.int64)
        high_idx = np.minimum(np.ceil(timesteps), 999).astype(np.int64)
        w = torch.from_numpy(timesteps - np.floor(timesteps))
        self.alphas_cumprod = self.alphas_cumprod.to('cpu')
        alphas = (1 - w) * self.alphas_cumprod[low_idx] + w * self.alphas_cumprod[high_idx]
        return ((1 - alphas) / alphas) ** 0.5

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
+
228
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver solves continuous-time diffusion ODEs; for DPMs trained on
    discrete-time labels we wrap the model into a noise-prediction function
    accepting continuous time.

    Supported ``model_type`` parameterizations:
        - "noise":   predicts the noise epsilon.
        - "x_start": predicts the data x_0; converted via
                     (x - alpha_t * x0) / sigma_t.
        - "v":       velocity prediction [Salimans & Ho 2022, arXiv:2202.00512];
                     converted via alpha_t * v + sigma_t * x.
        - "score":   marginal score; converted via noise = -sigma_t * score.

    Supported ``guidance_type``:
        - "uncond":          model(x, t_input, **model_kwargs)
        - "classifier":      classifier guidance [Dhariwal & Nichol 2021] using
                             ``classifier_fn(x, t_input, cond, **classifier_kwargs)``.
        - "classifier-free": classifier-free guidance [Ho & Salimans 2022] with
                             ``condition`` / ``unconditional_condition``.

    Args:
        model: the diffusion model in one of the formats above.
        noise_schedule: a noise schedule object (e.g. NoiseScheduleVP).
        model_type: "noise" | "x_start" | "v" | "score".
        model_kwargs: extra kwargs forwarded to ``model`` (None -> {}).
        guidance_type: "uncond" | "classifier" | "classifier-free".
        condition / unconditional_condition: guidance conditions.
        guidance_scale: guidance strength.
        classifier_fn / classifier_kwargs: classifier guidance inputs.
    Returns:
        ``model_fn(x, t_continuous) -> noise`` for use with DPM-Solver.
    """
    # Bug fix: "score" is fully implemented by noise_pred_fn below but was
    # missing from the original assertion; validate args up front.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    # Avoid mutable default arguments (a shared dict across calls).
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # Broadcast a scalar time label across the batch.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, None, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        # Convert the raw model output to a noise prediction.
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Batch the unconditional and conditional passes together.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                if isinstance(condition, torch.Tensor) and (
                    isinstance(unconditional_condition, torch.Tensor) or unconditional_condition is None
                ):
                    c_in = torch.cat([unconditional_condition, condition])
                else:
                    # NOTE(review): the list order [cond, uncond] here is the
                    # reverse of the tensor case [uncond, cond] above, while the
                    # chunk below always assumes uncond comes first. This looks
                    # inconsistent — confirm against how the wrapped model
                    # consumes list conditions before changing it.
                    c_in = [condition, unconditional_condition]
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
405
+
406
def weighted_cumsumexp_trapezoid(a, x, b, cumsum=True):
    """Trapezoidal (cumulative) integration of b * exp(a) over x, numerically
    stabilized by factoring out the per-trajectory maximum of `a`.

    Inputs a, x, b have shape (N+1, ...); b may be None (treated as 1).
    With ``cumsum=True`` returns y of shape (N+1, ...) where
        y_0 = 0,
        y_n = sum_{i=1}^{n} 0.5 * (x_i - x_{i-1}) * (b_i e^{a_i} + b_{i-1} e^{a_{i-1}}).
    With ``cumsum=False`` returns only the total integral y_N (shape (...)).
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Stabilize: work with exp(a - a_max) and multiply exp(a_max) back at the end.
    a_max = np.amax(a, axis=0, keepdims=True)
    integrand = np.exp(a - a_max) if b is None else np.asarray(b) * np.exp(a - a_max)

    # Trapezoid areas of each consecutive segment, shape (N, ...).
    segments = 0.5 * (x[1:] - x[:-1]) * (integrand[1:] + integrand[:-1])

    if not cumsum:
        return np.sum(segments, axis=0) * np.exp(a_max)

    running = np.cumsum(segments, axis=0) * np.exp(a_max)
    # Prepend the zero integral at the first node.
    return np.concatenate([np.zeros_like(running[[0]]), running], axis=0)
+
432
+
433
def weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=True):
    """Torch counterpart of :func:`weighted_cumsumexp_trapezoid`.

    Computes the (cumulative) trapezoidal integral of b * exp(a) over x with
    the max-subtraction trick for numerical stability. Inputs have shape
    (N+1, ...); b may be None. Returns shape (N+1, ...) when ``cumsum=True``
    (with a leading zero row), otherwise the total integral.
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Portability fix: the documented keyword for torch.amax is `keepdim`
    # (the original used numpy-style `keepdims`, which is not part of the
    # documented torch.amax signature).
    a_max = torch.amax(a, dim=0, keepdim=True)

    if b is not None:
        tmp = b * torch.exp(a - a_max)
    else:
        tmp = torch.exp(a - a_max)

    out = 0.5 * (x[1:] - x[:-1]) * (tmp[1:] + tmp[:-1])
    if not cumsum:
        return torch.sum(out, dim=0) * torch.exp(a_max)
    out = torch.cumsum(out, dim=0)
    out *= torch.exp(a_max)
    # Portability fix: torch.cat is the canonical name (torch.concat is a
    # late-added alias that older torch versions lack).
    return torch.cat([torch.zeros_like(out[[0]]), out], dim=0)
+
452
+
453
def index_list(lst, index):
    """Return the elements of `lst` at the positions given in `index`."""
    return [lst[i] for i in index]
+
459
+
460
class DPM_Solver_v3:
    """DPM-Solver-v3 sampler using precomputed "empirical model statistics"
    (EMS) l, s, b loaded from `statistics_dir`, integrated into the
    coefficients L, S, I, B, C used by the multistep predictor/corrector
    updates below.

    NOTE(review): several buffers are moved with hard-coded `.cuda()` calls,
    ignoring the `device` argument — confirm intended before running on CPU.
    """

    def __init__(
        self,
        statistics_dir,
        noise_schedule,
        steps=10,
        t_start=None,
        t_end=None,
        skip_type="time_uniform",
        degenerated=False,
        device="cuda",
    ):
        # steps: number of sampling steps; t_start/t_end default to the
        # schedule's [T, 1/N] range.
        self.device = device
        self.model = None
        self.noise_schedule = noise_schedule
        self.steps = steps
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"

        # Load the EMS arrays; `degenerated=True` replaces them with the
        # trivial statistics (l=1, s=b=0), reducing to a plain exponential
        # integrator baseline.
        l = np.load(os.path.join(statistics_dir, "l.npz"))["l"]
        sb = np.load(os.path.join(statistics_dir, "sb.npz"))
        s, b = sb["s"], sb["b"]
        if degenerated:
            l = np.ones_like(l)
            s = np.zeros_like(s)
            b = np.zeros_like(b)
        self.statistics_steps = l.shape[0] - 1
        # Fine logSNR grid (one entry per statistics step), shaped (K+1,1,1,1)
        # for broadcasting against image-shaped tensors.
        ts = noise_schedule.marginal_lambda(
            self.get_time_steps("logSNR", t_T, t_0, self.statistics_steps, "cpu")
        ).numpy()[:, None, None, None]
        self.ts = torch.from_numpy(ts).cuda()  # NOTE(review): hard-coded cuda
        self.lambda_T = self.ts[0].cpu().item()
        self.lambda_0 = self.ts[-1].cpu().item()
        z = np.zeros_like(l)
        o = np.ones_like(l)
        # Integrated statistics (trapezoid rule over the logSNR grid):
        #   L = ∫ l dλ,  S = ∫ s dλ,  I = ∫ e^{L+S} dλ,
        #   B = ∫ b e^{-S} dλ,  C = ∫ B e^{L+S} dλ.
        L = weighted_cumsumexp_trapezoid(z, ts, l)
        S = weighted_cumsumexp_trapezoid(z, ts, s)

        I = weighted_cumsumexp_trapezoid(L + S, ts, o)
        B = weighted_cumsumexp_trapezoid(-S, ts, b)
        C = weighted_cumsumexp_trapezoid(L + S, ts, B)
        self.l = torch.from_numpy(l).cuda()
        self.s = torch.from_numpy(s).cuda()
        self.b = torch.from_numpy(b).cuda()
        self.L = torch.from_numpy(L).cuda()
        self.S = torch.from_numpy(S).cuda()
        self.I = torch.from_numpy(I).cuda()
        self.B = torch.from_numpy(B).cuda()
        self.C = torch.from_numpy(C).cuda()

        # precompute timesteps: `indexes` locates each sampling time on the
        # fine statistics grid; timesteps are then snapped back so the two
        # stay consistent.
        if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic" or skip_type == "customed_time_karras":
            self.timesteps = self.get_time_steps(skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            self.indexes = self.convert_to_indexes(self.timesteps)
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        elif skip_type == "edm":
            self.indexes, self.timesteps = self.get_timesteps_edm(N=steps, device=device)
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        else:
            raise ValueError(f"Unsupported timestep strategy {skip_type}")

        print("Indexes", self.indexes)
        print("Time steps", self.timesteps)
        print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))

        # store high-order exponential coefficients (lazy)
        self.exp_coeffs = {}

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def convert_to_indexes(self, timesteps):
        """Map continuous timesteps to the nearest indices on the fine
        statistics logSNR grid."""
        logSNR_steps = self.noise_schedule.marginal_lambda(timesteps)
        indexes = list(
            (self.statistics_steps * (logSNR_steps - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes

    def convert_to_timesteps(self, indexes, device):
        """Inverse of convert_to_indexes: grid indices -> continuous times."""
        logSNR_steps = (
            self.lambda_T + (self.lambda_0 - self.lambda_T) * torch.Tensor(indexes).to(device) / self.statistics_steps
        )
        return self.noise_schedule.inverse_lambda(logSNR_steps)

    def append_zero(self, x):
        # Append a trailing zero sigma (Karras convention for the final step).
        return torch.cat([x, x.new_zeros([1])])

    def get_sigmas_karras(self, n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
        """Constructs the noise schedule of Karras et al. (2022)."""
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return self.append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)

    def sigma_to_t(self, sigma, quantize=None):
        """Map a Karras sigma to a fractional discrete timestep via linear
        interpolation in log-sigma space.

        NOTE(review): duplicates NoiseScheduleVP.sigma_to_t, and `quantize`
        is forced to False so the nearest-index branch is dead code.
        """
        quantize = False
        log_sigma = sigma.log()
        dists = log_sigma - self.noise_schedule.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.noise_schedule.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.noise_schedule.log_sigmas[low_idx], self.noise_schedule.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
                - 'customed_time_karras': hand-tuned Karras-style schedules (only N in {5, 6, 8}).
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: A `int`. The total number of the spacing of the time steps.
            device: A torch device.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        elif skip_type == "customed_time_karras":
            # Hand-tuned schedules: build an over-long Karras sigma ramp,
            # truncate it, then respace the truncated range with a second
            # (rho=1.2) Karras ramp in timestep space.
            # NOTE(review): only N in {5, 6, 8} define `real_ct`; any other N
            # raises NameError at the `none_k_ct` line below — confirm whether
            # other step counts should be rejected explicitly.
            sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
            sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
            if N == 8:
                sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
                ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 5:
                sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 6:
                # NOTE(review): double assignment `sigmas = self.sigmas = ...`
                # also stores the ramp on the instance; looks unintentional.
                sigmas = self.sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            # Normalized fractional timesteps (divided by 999 above) as a tensor.
            none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
            return none_k_ct
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
            )

    def get_timesteps_edm(self, N, device):
        """Constructs the noise schedule of Karras et al. (2022)."""

        rho = 7.0  # 7.0 is the value used in the paper

        # sigma = exp(-lambda), so the schedule endpoints come from the
        # precomputed logSNR range.
        sigma_min: float = np.exp(-self.lambda_0)
        sigma_max: float = np.exp(-self.lambda_T)
        ramp = np.linspace(0, 1, N + 1)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        lambdas = torch.Tensor(-np.log(sigmas)).to(device)
        timesteps = self.noise_schedule.inverse_lambda(lambdas)

        # Snap each lambda onto the fine statistics grid.
        indexes = list(
            (self.statistics_steps * (lambdas - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes, timesteps

    def get_g(self, f_t, i_s, i_t):
        # Transform f into the g-space used by DPM-Solver-v3
        # (scaled by e^{S_s - S_t} with the B-integral offset).
        return torch.exp(self.S[i_s] - self.S[i_t]) * f_t - torch.exp(self.S[i_s]) * (self.B[i_t] - self.B[i_s])

    def compute_exponential_coefficients_high_order(self, i_s, i_t, order=2):
        """Compute (and cache) the high-order exponential integral coefficient
        ∫ (λ-λ_s)^{order-1}/(order-1)! * e^{(L+S)-(L_s+S_s)} dλ over [i_s, i_t]."""
        key = (i_s, i_t, order)
        if key in self.exp_coeffs.keys():
            coeffs = self.exp_coeffs[key]
        else:
            n = order - 1
            a = self.L[i_s : i_t + 1] + self.S[i_s : i_t + 1] - self.L[i_s] - self.S[i_s]
            x = self.ts[i_s : i_t + 1]
            b = (self.ts[i_s : i_t + 1] - self.ts[i_s]) ** n / math.factorial(n)
            coeffs = weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=False)
            self.exp_coeffs[key] = coeffs
        return coeffs

    def compute_high_order_derivatives(self, n, lambda_0n, g_0n, pseudo=False):
        # return g^(1), ..., g^(n)
        if pseudo:
            # Pseudo-order: divided-difference (Newton) table; only the
            # leading entry of each level is used.
            D = [[] for _ in range(n + 1)]
            D[0] = g_0n
            for i in range(1, n + 1):
                for j in range(n - i + 1):
                    D[i].append((D[i - 1][j] - D[i - 1][j + 1]) / (lambda_0n[j] - lambda_0n[i + j]))

            return [D[i][0] * math.factorial(i) for i in range(1, n + 1)]
        else:
            # Full order: solve the Vandermonde-like linear system built from
            # powers of (λ_i - λ_0) for the Taylor coefficients of g.
            R = []
            for i in range(1, n + 1):
                R.append(torch.pow(lambda_0n[1:] - lambda_0n[0], i))
            R = torch.stack(R).t()
            B = (torch.stack(g_0n[1:]) - g_0n[0]).reshape(n, -1)
            shape = g_0n[0].shape
            solution = torch.linalg.inv(R) @ B
            solution = solution.reshape([n] + list(shape))
            return [solution[i - 1] * math.factorial(i) for i in range(1, n + 1)]

    def multistep_predictor_update(self, x_lst, eps_lst, time_lst, index_lst, t, i_t, order=1, pseudo=False):
        """Advance x from the latest cached state to time t (grid index i_t)
        using an order-`order` multistep predictor over the cached history.

        x_lst: [..., x_s]; eps_lst: [..., eps_s]; time_lst: [..., time_s].
        """
        ns = self.noise_schedule
        n = order - 1
        # Take the `order` most recent history entries, newest first.
        indexes = [-i - 1 for i in range(n + 1)]
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()  # NOTE(review): hard-coded cuda
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], ns.marginal_alpha(t)
        i_s = index_0n[0]
        x_s = x_0n[0]
        # Convert cached eps predictions into g-space values.
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        # First-order (exponential-integrator) update.
        x_t = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        # Higher-order corrections from estimated derivatives of g.
        if order > 1:
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t = (
                    x_t
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t

    def multistep_corrector_update(self, x_lst, eps_lst, time_lst, index_lst, order=1, pseudo=False):
        """Recompute the latest state x_t using the model output already
        evaluated at t (implicit corrector step). Requires order >= 2.

        x_lst: [..., x_s, x_t]; eps_lst: [..., eps_s, eps_t].
        """
        ns = self.noise_schedule
        n = order - 1
        # History newest-first, but anchored at x_s = second-to-last entry
        # (the corrector re-solves the step s -> t that was just taken).
        indexes = [-i - 1 for i in range(n + 1)]
        indexes[0] = -2
        indexes[1] = -1
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()  # NOTE(review): hard-coded cuda
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], alpha_0n[1]
        i_s, i_t = index_0n[0], index_0n[1]
        x_s = x_0n[0]
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        # Same update formula as the predictor, now including the new eps_t.
        x_t_new = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        if order > 1:
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t_new = (
                    x_t_new
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t_new

    def sample(
        self,
        x,
        model_fn,
        order,
        p_pseudo,
        use_corrector,
        c_pseudo,
        lower_order_final,
        start_free_u_step=None,
        free_u_apply_callback=None,
        free_u_stop_callback=None,
        half=False,
        return_intermediate=False,
    ):
        """Run the full predictor(-corrector) sampling loop.

        Args:
            x: the initial noise tensor.
            model_fn: noise-prediction function ``(x, t_batch) -> eps``.
            order: maximum predictor order (warmed up over the first steps).
            p_pseudo / c_pseudo: use pseudo-order derivatives in the
                predictor / corrector.
            use_corrector: apply the implicit corrector each step.
            lower_order_final: reduce the order near the end (stability).
            start_free_u_step / free_u_apply_callback / free_u_stop_callback:
                hooks to toggle FreeU at a given step (stop is called before
                the loop so FreeU starts disabled).
            half: when True, skip the corrector once t <= 0.5.
            return_intermediate: also return the list of cached states.
        Returns:
            The final sample x (and the cached intermediates if requested).
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        steps = self.steps
        cached_x = []
        cached_model_output = []
        cached_time = []
        cached_index = []
        indexes, timesteps = self.indexes, self.timesteps
        step_p_order = 0
        if free_u_stop_callback is not None:
            free_u_stop_callback()
        for step in range(1, steps + 1):
            if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None:
                free_u_apply_callback()
            # Evaluate the model at the current time and cache the state.
            cached_x.append(x)
            cached_model_output.append(self.noise_prediction_fn(x, timesteps[step - 1]))
            cached_time.append(timesteps[step - 1])
            cached_index.append(indexes[step - 1])
            # Corrector: refine the latest state using the fresh model output,
            # then rewrite the cached eps so it stays consistent with the
            # corrected x (N_old is the l-detrended model output).
            if use_corrector and (timesteps[step - 1] > 0.5 or not half):
                step_c_order = step_p_order + c_pseudo
                if step_c_order > 1:
                    x_new = self.multistep_corrector_update(
                        cached_x, cached_model_output, cached_time, cached_index, order=step_c_order, pseudo=c_pseudo
                    )
                    sigma_t = self.noise_schedule.marginal_std(cached_time[-1])
                    l_t = self.l[cached_index[-1]]
                    N_old = sigma_t * cached_model_output[-1] - l_t * cached_x[-1]
                    cached_x[-1] = x_new
                    cached_model_output[-1] = (N_old + l_t * cached_x[-1]) / sigma_t
            # Warm up the predictor order, optionally tapering at the end.
            if step < order:
                step_p_order = step
            else:
                step_p_order = order
            if lower_order_final:
                step_p_order = min(step_p_order, steps + 1 - step)
            t = timesteps[step]
            i_t = indexes[step]

            x = self.multistep_predictor_update(
                cached_x, cached_model_output, cached_time, cached_index, t, i_t, order=step_p_order, pseudo=p_pseudo
            )

        if return_intermediate:
            return x, cached_x
        else:
            return x
841
+
842
+
843
+ #############################################################
844
+ # other utility functions
845
+ #############################################################
846
+
847
+
848
def interpolate_fn(x, xp, yp):
    """Differentiable piecewise-linear function y = f(x) through keypoints.

    For queries outside the keypoint range, the outermost segment is extended
    linearly. Implemented with sort/gather so the whole computation stays
    autograd-friendly (no data-dependent Python branching).

    Args:
        x: query points, shape [N, C] (C = 1 for DPM-Solver).
        xp: keypoint x-values, shape [C, K].
        yp: keypoint y-values, shape [C, K].
    Returns:
        Interpolated values f(x), shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    # Insert each query among its keypoints and sort to locate its rank.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    combined_sorted, order = torch.sort(combined, dim=2)
    pos = torch.argmin(order, dim=2)  # rank of the query within the keypoints
    left_cand = pos - 1
    # Clamp the segment to the outermost pair when the query falls outside.
    seg_left = torch.where(
        torch.eq(pos, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(pos, num_keys),
            torch.tensor(num_keys - 2, device=x.device),
            left_cand,
        ),
    )
    seg_right = torch.where(torch.eq(seg_left, left_cand), seg_left + 2, seg_left + 1)
    x_left = torch.gather(combined_sorted, dim=2, index=seg_left.unsqueeze(2)).squeeze(2)
    x_right = torch.gather(combined_sorted, dim=2, index=seg_right.unsqueeze(2)).squeeze(2)
    # Matching indices into yp (which excludes the inserted query point).
    y_left_idx = torch.where(
        torch.eq(pos, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(pos, num_keys),
            torch.tensor(num_keys - 2, device=x.device),
            left_cand,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(batch, -1, -1)
    y_left = torch.gather(yp_expanded, dim=2, index=y_left_idx.unsqueeze(2)).squeeze(2)
    y_right = torch.gather(yp_expanded, dim=2, index=(y_left_idx + 1).unsqueeze(2)).squeeze(2)
    return y_left + (x - x_left) * (y_right - y_left) / (x_right - x_left)
892
+
893
+
894
def expand_dims(v, dims):
    """Right-pad `v` with singleton axes until it has `dims` dimensions.

    Args:
        v: a PyTorch tensor with shape [N].
        dims: target total number of dimensions (an `int`).
    Returns:
        A tensor of shape [N, 1, ..., 1] with `dims` total dimensions.
    """
    trailing = dims - 1
    return v.reshape(v.shape + (1,) * trailing)
free_lunch_utils.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.fft as fft
3
+ from diffusers.utils import is_torch_version
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+
7
def isinstance_str(x: object, cls_name: str):
    """
    Checks whether any class *named* `cls_name` appears in x's ancestry.

    Works without access to the class's implementation, which makes it
    useful for monkey-patching third-party objects.
    """
    return any(base.__name__ == cls_name for base in type(x).__mro__)
20
+
21
+
22
def Fourier_filter(x, threshold, scale):
    """Scale the low-frequency band of `x` in the 2-D Fourier domain (FreeU).

    The centred FFT over the last two dims is multiplied by `scale` inside a
    (2*threshold) x (2*threshold) box around the zero frequency and left
    unchanged elsewhere, then transformed back.

    Args:
        x: tensor of shape [B, C, H, W] (any float dtype; FFT runs in fp32).
        threshold: half-size, in frequency bins, of the low-frequency box.
        scale: multiplier applied to the low-frequency band.
    Returns:
        Filtered tensor with the same shape and dtype as `x`.
    """
    dtype = x.dtype
    x = x.type(torch.float32)  # do the FFT in full precision
    # FFT
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Build the mask on x's own device. (The previous version hard-coded
    # .cuda(), which crashed on CPU and broke non-default GPU placement.)
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    x_filtered = x_filtered.type(dtype)
    return x_filtered
42
+
43
+
44
def register_upblock2d(model):
    """Re-install the stock forward pass on every UpBlock2D in model.unet.

    Patches each matching up-block with a closure that mirrors the standard
    diffusers UpBlock2D.forward (resnets + optional upsamplers, with gradient
    checkpointing support) and no FreeU scaling — i.e. it undoes
    `register_free_upblock2d`.

    Args:
        model: object exposing ``.unet.up_blocks`` (a diffusers pipeline).
    """
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    # use_reentrant only exists on newer torch versions.
                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
84
+
85
+
86
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    """Patch every UpBlock2D in model.unet with a FreeU-enabled forward.

    FreeU scales part of the backbone features (by ``b1``/``b2``, on the
    stages whose channel width is 1280 or 640) and damps the low-frequency
    band of the matching skip features (by ``s1``/``s2`` via Fourier_filter).
    The factors are stored on each block as attributes so re-registering with
    different values changes the behaviour.

    Args:
        model: object exposing ``.unet.up_blocks`` (a diffusers pipeline).
        b1, b2: backbone scaling factors for the 1280-/640-channel stages.
        s1, s2: skip-feature low-frequency scales for those stages.
    """
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                # NOTE: modifies hidden_states in place before concatenation.
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
141
+
142
+
143
def register_crossattn_upblock2d(model):
    """Re-install the stock forward pass on every CrossAttnUpBlock2D.

    Patches each matching up-block in model.unet with a closure mirroring the
    standard diffusers CrossAttnUpBlock2D.forward (resnet + cross-attention
    per stage, gradient checkpointing supported) and no FreeU scaling.

    Args:
        model: object exposing ``.unet.up_blocks`` (a diffusers pipeline).
    """
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
213
+
214
+
215
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    """Patch every CrossAttnUpBlock2D in model.unet with a FreeU forward.

    Same FreeU transformation as `register_free_upblock2d` (backbone scaling
    by ``b1``/``b2`` on the 1280-/640-channel stages and low-frequency skip
    damping by ``s1``/``s2``), followed by the block's resnet + attention.
    The factors are stored as attributes on each patched block.

    Args:
        model: object exposing ``.unet.up_blocks`` (a diffusers pipeline).
        b1, b2: backbone scaling factors for the 1280-/640-channel stages.
        s1, s2: skip-feature low-frequency scales for those stages.
    """
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    # hidden_states = attn(
                    #     hidden_states,
                    #     encoder_hidden_states=encoder_hidden_states,
                    #     cross_attention_kwargs=cross_attention_kwargs,
                    #     encoder_attention_mask=encoder_attention_mask,
                    #     return_dict=False,
                    # )[0]
                    # NOTE(review): unlike the stock block, attention_mask /
                    # encoder_attention_mask are not forwarded here — confirm
                    # that is intentional for this pipeline.
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ einops
3
+ pytorch_lightning
4
+ accelerate
5
+ torchsde
6
+ pycocotools
7
+ diffusers
8
+ timm
9
+ transformers
10
+ opencv-python
11
+ omegaconf
sampler.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3
6
+ from uni_pc import UniPC
7
+ from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
8
+
9
+
10
class DPMSolverv3Sampler:
    """Wrapper tying a diffusers pipeline to the DPM-Solver-v3 sampler.

    Builds a discrete NoiseScheduleVP from the pipeline's scheduler and a
    DPM_Solver_v3 instance from precomputed statistics, and exposes a
    classifier-free-guided `sample` method plus FreeU on/off hooks.
    """

    def __init__(self, stats_dir, pipe, steps, guidance_scale, **kwargs):
        """
        Args:
            stats_dir: directory with precomputed DPM-Solver-v3 statistics.
            pipe: a diffusers pipeline (``.unet`` / ``.scheduler`` are used).
            steps: number of sampling steps.
            guidance_scale: classifier-free guidance scale.
        """
        super().__init__()
        self.model = pipe
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(pipe.device)
        DTYPE = torch.float32  # torch.float16 works as well, but pictures seem to be a bit worse
        # NOTE(review): device is hard-coded to "cuda" even though pipe.device
        # is available — confirm this is intentional (GPU-only deployment).
        device = "cuda"
        noise_scheduler = pipe.scheduler
        alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.alphas_cumprod = alpha_schedule  #to_torch(model.alphas_cumprod)
        self.device = device
        self.guidance_scale = guidance_scale

        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        assert stats_dir is not None, f"No statistics file found in {stats_dir}."
        print("Use statistics", stats_dir)
        self.dpm_solver_v3 = DPM_Solver_v3(
            statistics_dir=stats_dir,
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type="customed_time_karras",
            degenerated=False,
            device=self.device,
        )
        self.steps = steps

    @torch.no_grad()
    def apply_free_unet(self):
        """Enable FreeU on the pipeline's UNet up-blocks."""
        register_free_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)

    @torch.no_grad()
    def stop_free_unet(self):
        """Disable FreeU by re-registering identity scales."""
        register_free_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Draw `batch_size` latents of `shape` with DPM-Solver-v3.

        Args:
            batch_size: number of samples.
            shape: latent shape (C, H, W).
            conditioning: text-conditioning embeds (None = unconditional).
            x_T: optional initial noise; drawn from randn when None.
            unconditional_conditioning: negative-prompt embeds for CFG.
            use_corrector: enable the solver's multistep corrector.
            half: apply the corrector only during the early trajectory.
            start_free_u_step: step index at which FreeU is switched on.

        Returns:
            (samples, None) — second element mirrors other samplers'
            (samples, intermediates) interface.
        """
        if conditioning is not None:
            # NOTE(review): cond_in is computed but never used below.
            cond_in = torch.cat([unconditional_conditioning, conditioning])
            # extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.guidance_scale}
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=self.guidance_scale,
            )
            # NOTE(review): solver order is keyed to an 8-step budget —
            # presumably tuned empirically; confirm for other step counts.
            if self.steps == 8:
                ORDER = 2
            else:
                ORDER = 1

        x = self.dpm_solver_v3.sample(
            img,
            model_fn,
            order=ORDER,
            p_pseudo=False,
            c_pseudo=True,
            lower_order_final=True,
            use_corrector=use_corrector,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )

        return x.to(self.device), None
121
+
122
+
123
class UniPCSampler:
    """UniPC-based sampler wrapped around a diffusers pipeline.

    Builds a discrete NoiseScheduleVP from the pipeline scheduler and a UniPC
    solver. Supports AFS (analytic first step) for few-step high-resolution
    sampling, FreeU toggling, an optional SDXL preprocessing closure, and an
    optional noise-prediction network (npnet) that refines the initial noise.
    """

    def __init__(
        self,
        pipe,
        model_closure,
        steps,
        guidance_scale, denoise_to_zero=False,
        need_fp16_discrete_method=False,
        ultilize_vae_in_fp16=False,
        is_high_resoulution=True,
        skip_type="customed_time_karras",
        force_not_use_afs=False,
        **kwargs,
    ):
        """
        Args:
            pipe: diffusers pipeline (provides ``.scheduler``/``.unet``/``.device``).
            model_closure: factory turning the pipeline into an (x, t, c) model.
            steps: number of sampling steps.
            guidance_scale: classifier-free guidance scale.
            denoise_to_zero: forwarded to UniPC timestep construction.
            need_fp16_discrete_method: fp16 discrete-method flag; also caps the
                solver order in `sample_mix`.
            ultilize_vae_in_fp16: forwarded to UniPC (spelling kept as-is).
            is_high_resoulution: high-resolution flag (spelling kept for
                interface compatibility); gates AFS.
            skip_type: timestep spacing strategy.
            force_not_use_afs: disable AFS even when it would otherwise apply.
        """
        super().__init__()
        # self.model = pipe
        self.model = model_closure(pipe)
        self.pipe = pipe
        self.need_fp16_discrete_method = need_fp16_discrete_method
        # to_torch = lambda x: x.clone().detach().to(torch.float32).to(pipe.device)
        DTYPE = self.pipe.unet.dtype  # torch.float16 works as well, but pictures seem to be a bit worse
        device = self.pipe.device
        noise_scheduler = pipe.scheduler
        alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.alphas_cumprod = alpha_schedule  #to_torch(model.alphas_cumprod)
        self.device = device
        self.guidance_scale = guidance_scale
        # AFS is only used for few-step (<= 8), high-resolution runs.
        self.use_afs = steps <= 8 and is_high_resoulution and not force_not_use_afs

        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        self.unipc_solver = UniPC(
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type=skip_type,
            degenerated=False,
            use_afs=self.use_afs,
            device=self.device,
            denoise_to_zero=denoise_to_zero,
            need_fp16_discrete_method=self.need_fp16_discrete_method,
            ultilize_vae_in_fp16=ultilize_vae_in_fp16,
            is_high_resoulution=is_high_resoulution,
        )
        self.steps = steps

    @torch.no_grad()
    def apply_free_unet(self):
        """Enable FreeU on the pipeline's UNet up-blocks."""
        register_free_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)

    @torch.no_grad()
    def stop_free_unet(self):
        """Disable FreeU by re-registering identity scales."""
        register_free_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Sample latents with the UniPC solver.

        When `xl_preprocess_closure` is given, prompts are encoded (SDXL path)
        and the resulting embeds replace the raw conditioning. When `npnet` is
        additionally given, a refined initial noise is produced and handed to
        the solver as `npnet_x` along with `npnet_scale`.

        Returns:
            (samples, full_cache) as produced by the UniPC solver.
        """

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        new_img = None
        if xl_preprocess_closure is not None:
            prompt_embeds, cond_kwargs = xl_preprocess_closure(pipe=self.pipe, prompts=conditioning, need_cfg=True, device=self.device, negative_prompts=unconditional_conditioning)
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T
        if xl_preprocess_closure is not None and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add dummy dimension for npnet
            new_img = npnet(img, c)

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
                unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
                guidance_scale=self.guidance_scale,
            )
            # NOTE(review): order keyed to the step budget (>= 7 → 2),
            # presumably tuned empirically.
            if self.steps >= 7:
                ORDER = 2
            else:
                ORDER = 1

        x, full_cache = self.unipc_solver.sample(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            npnet_x=new_img if new_img is not None else None,
            npnet_scale=self.guidance_scale if new_img is not None else None,
            half=half,
        )

        return x.to(self.device), full_cache

    @torch.no_grad()
    def sample_mix(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Sample via the solver's `sample_mix` entry point.

        Differs from `sample` in that the npnet output (when available)
        directly replaces the initial noise instead of being passed to the
        solver separately, and the order-2 threshold additionally requires
        the fp16 discrete method to be disabled.

        Returns:
            (samples, full_cache) as produced by the UniPC solver.
        """

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        if xl_preprocess_closure is not None:
            prompt_embeds, cond_kwargs = xl_preprocess_closure(pipe=self.pipe, prompts=conditioning, need_cfg=True, device=self.device, negative_prompts=unconditional_conditioning)
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T
        if xl_preprocess_closure is not None and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add dummy dimension for npnet
            img = npnet(img, c)

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
                unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
                guidance_scale=self.guidance_scale,
            )
            if self.steps >= 8 and not self.need_fp16_discrete_method:
                ORDER = 2
            else:
                ORDER = 1

        x, full_cache = self.unipc_solver.sample_mix(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )

        return x.to(self.device), full_cache
uni_pc.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dpm_solver_v3 import NoiseScheduleVP, model_wrapper
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ import numpy as np
6
+ import os
7
+
8
+ class UniPC:
9
    def __init__(
        self,
        noise_schedule,
        steps=10,
        t_start=None,
        t_end=None,
        skip_type="customed_time_karras",
        degenerated=False,
        use_afs=False,
        denoise_to_zero=False,
        need_fp16_discrete_method=False,
        ultilize_vae_in_fp16=False,
        is_high_resoulution=True,
        device="cuda",
    ):
        """Configure the UniPC solver and precompute its timestep grid.

        Args:
            noise_schedule: a NoiseScheduleVP ("discrete") instance.
            steps: number of sampling steps (one extra is budgeted when AFS
                is enabled, since AFS replaces the first model evaluation).
            t_start, t_end: sampling time range; defaults to [1/N, T] of the
                discrete schedule when None.
            skip_type: timestep spacing strategy; one of "logSNR",
                "time_uniform", "time_quadratic", "customed_time_karras".
            degenerated: accepted for interface parity; not used here.
            use_afs: enable the analytic first step.
            denoise_to_zero: forwarded to the timestep construction.
            need_fp16_discrete_method: fp16 discrete-method flag (stored).
            ultilize_vae_in_fp16: fp16 VAE flag (stored; spelling kept).
            is_high_resoulution: high-resolution flag (spelling kept for
                interface compatibility).
            device: device on which timesteps are created.
        """
        self.device = device
        self.model = None  # set later by the sampling entry point
        self.noise_schedule = noise_schedule
        # AFS replaces the first model call, so one extra step is budgeted.
        self.steps = steps if not use_afs else steps + 1
        self.use_afs = use_afs
        self.ultilize_vae_in_fp16 = ultilize_vae_in_fp16
        self.need_fp16_discrete_method = need_fp16_discrete_method
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        self.is_high_resolution = is_high_resoulution
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"

        # precompute timesteps
        # (note: uses the constructor's `steps`, not the AFS-adjusted count)
        if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic" or skip_type == "customed_time_karras":
            self.timesteps = self.get_time_steps(
                skip_type,
                t_T=t_T,
                t_0=t_0,
                N=steps,
                device=device,
                denoise_to_zero=denoise_to_zero,
                is_high_resolution=self.is_high_resolution,
            )
        else:
            raise ValueError(f"Unsupported timestep strategy {skip_type}")
        self.lambda_T = self.timesteps[0].cpu().item()
        self.lambda_0 = self.timesteps[-1].cpu().item()

        # print("Time steps", self.timesteps)
        # print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))

        # store high-order exponential coefficients (lazy)
        self.exp_coeffs = {}
57
+
58
+ def noise_prediction_fn(self, x, t):
59
+ """
60
+ Return the noise prediction model.
61
+ """
62
+ return self.model(x, t)
63
+
64
+ def append_zero(self, x):
65
+ return torch.cat([x, x.new_zeros([1])])
66
+
67
+ def get_sigmas_karras(self, n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
68
+ """Constructs the noise schedule of Karras et al. (2022)."""
69
+ ramp = torch.linspace(0, 1, n)
70
+ min_inv_rho = sigma_min ** (1 / rho)
71
+ max_inv_rho = sigma_max ** (1 / rho)
72
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
73
+ return self.append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
74
+
75
    def sigma_to_t(self, sigma, quantize=None):
        """Map sigma values to (fractional) discrete timesteps.

        Linearly interpolates log(sigma) between the two neighbouring entries
        of ``self.noise_schedule.log_sigmas``.

        Args:
            sigma: tensor of sigma values.
            quantize: ignored — overridden to False below, so the fractional
                (interpolated) timestep is always returned.
        Returns:
            Tensor of timesteps with the same shape as `sigma`.
        """
        quantize = False  # nearest-index lookup is deliberately disabled
        log_sigma = sigma.log()
        # Signed distance of log(sigma) from every tabulated log-sigma.
        dists = log_sigma - self.noise_schedule.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        # Pick neighbouring table entries; clamp keeps low_idx + 1 in range.
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.noise_schedule.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.noise_schedule.log_sigmas[low_idx], self.noise_schedule.log_sigmas[high_idx]
        # Interpolation weight, clamped to [0, 1] for out-of-range sigmas.
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)
88
+
89
    def get_time_steps(self, skip_type, t_T, t_0, N, device, denoise_to_zero=False, is_high_resolution=True):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support these types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
                - 'customed_time_karras': hand-tuned Karras-style schedules for N in {5, 6, 7, 8};
                  the exact recipe also depends on `is_high_resolution`, `denoise_to_zero`,
                  `self.need_fp16_discrete_method` and `self.use_afs`.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: A `int`. The total number of the spacing of the time steps.
            device: A torch device.
            denoise_to_zero: When True, some 'customed_time_karras' branches append `t_0`
                as an extra final denoising step.
            is_high_resolution: Selects between the high- and low-resolution Karras recipes.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            # Uniform spacing in log-SNR (lambda) space, mapped back to time.
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            # Quadratic spacing: uniform in sqrt(t), then squared back.
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        elif skip_type == "customed_time_karras" and is_high_resolution:
            # NOTE(review): assumes the schedule's `sigmas` array is ascending so that
            # [-1] is the largest and [0] the smallest sigma -- confirm against the
            # noise-schedule class.
            sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
            sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
            # Each branch below: (1) build a Karras sigma ladder, (2) convert a
            # sub-range of it to continuous times, (3) re-space those times with a
            # second (rho=1.2) Karras ramp, (4) map back through the schedule.
            # NOTE(review): `self.noise_schedule.sigma_to_t` and `self.sigma_to_t`
            # are mixed within single expressions -- verify they agree.
            # The `/ 999` presumably maps a discrete index in [0, 999] to a
            # normalized time in (0, 1] -- TODO confirm.
            if N == 8:
                sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T,rho=12.0, device=device)
                if not self.need_fp16_discrete_method:
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[10])
                    ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                else:
                    # fp16-discrete path uses a shorter, shallower (rho=5) ladder.
                    sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                    ct = self.get_sigmas_karras(8, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    tmp_t = [self.noise_schedule.sigma_to_t(sigma).to('cpu') for sigma in sigmas_ct]
                    real_ct = [ t / 999 for t in tmp_t]
            elif N == 5:
                sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                if not self.need_fp16_discrete_method:
                    # Replaces the ladder above with a denser rho=12 one.
                    sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T,rho=12.0, device=device)
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
                    ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                else:
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                    ct = self.get_sigmas_karras(5, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 6:
                sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                if not self.need_fp16_discrete_method:
                    sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T,rho=12.0, device=device)
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[10])
                    ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                else:
                    if denoise_to_zero:
                        ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                        ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                        sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                        real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                        # Extra explicit final step at t_0 when denoising to zero.
                        real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype,device='cpu'))
                    else:
                        sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
                        ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[7])
                        ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                        sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                        real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 7:
                sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                if not self.need_fp16_discrete_method:
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                    ct = self.get_sigmas_karras(8, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                else:
                    # NOTE(review): when need_fp16_discrete_method is True and
                    # denoise_to_zero is False, `real_ct` is never assigned and the
                    # code below raises NameError. Same for any N outside {5,6,7,8}.
                    if denoise_to_zero:
                        ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                        ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                        sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                        real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                        real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype,device='cpu'))
            # if denoise_to_zero:
            #     real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype,device='cpu'))

            if self.use_afs:
                # AFS (analytic first step): insert a midpoint between the first two
                # timesteps so the extra pseudo-step has a time of its own.
                tmp_t = (real_ct[0] + real_ct[1]) / 2
                real_ct.insert(1, tmp_t)
            none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
            return none_k_ct#real_ct
        elif skip_type == "customed_time_karras" and not is_high_resolution:
            # Low-resolution recipes: same three-stage construction, different ladders.
            sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
            sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
            if N == 8:
                sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
                ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 5:
                sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 6:
                # NOTE(review): also stores the ladder on `self.sigmas` as a side
                # effect -- looks accidental; confirm nothing reads it.
                sigmas = self.sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
            return none_k_ct#real_ct
        else:
            # NOTE(review): message omits 'customed_time_karras' even though it is supported.
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
            )
215
+
216
+
217
+ def multistep_uni_pc_update(self, x, model_prev_list:list, t_prev_list: list, t, order, **kwargs):
218
+ if len(model_prev_list) == 0 or len(t_prev_list) == 0:
219
+ return None, None
220
+ if len(t.shape) == 0:
221
+ t = t.view(-1)
222
+ if True:#'bh' in self.variant:
223
+ return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
224
+ else:
225
+ # assert self.variant == 'vary_coeff'
226
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
227
+
228
+ def multistep_uni_pc_sde_update(self, x, model_prev_list:list, t_prev_list: list, t, order, level = 1.0, **kwargs):
229
+ if len(model_prev_list) == 0 or len(t_prev_list) == 0:
230
+ return None, None
231
+ if len(t.shape) == 0:
232
+ t = t.view(-1)
233
+ if True:#'bh' in self.variant:
234
+ return self.multistep_uni_pc_bh_sde_update(x, model_prev_list, t_prev_list, t, level=level, order= order, **kwargs)
235
+ else:
236
+ # assert self.variant == 'vary_coeff'
237
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
238
+
239
    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        """Single multistep UniPC update (B(h) variant) advancing ``x`` to time ``t``.

        Args:
            x: current sample; the einsum residuals below assume 4-D model
               outputs (B, C, H, W).
            model_prev_list: previous model (noise) evaluations, newest last.
            t_prev_list: matching timesteps, newest last.
            t: target timestep.
            order: solver order; must not exceed len(model_prev_list).
            x_t: optional precomputed predictor result (skips the predictor).
            use_corrector: run the corrector (costs one extra model evaluation).
        Returns:
            (x_t, model_t): updated sample and the model evaluated at (x_t, t),
            or model_t=None when the corrector was skipped.
        """
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()

        # first compute rks: normalized lambda-distances of the history points.
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)  # NOTE: unused below, as is sigma_prev_0

        h = lambda_t - lambda_prev_0  # step size in log-SNR space

        rks = []
        D1s = []  # scaled first differences of the model history
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        # Vandermonde-like system R @ rho = b giving the combination coefficients.
        R = []
        b = []

        hh = h[0]
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        # B(h) = h corresponds to the 'bh1' variant; expm1(h) would be 'bh2'.
        if True:
            B_h = hh
        else:
            B_h = torch.expm1(hh)

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K, C, H, W)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None

        # Base (first-order) update shared by predictor and corrector.
        x_t_ = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
        )
        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

        if use_corrector:
            # Corrector re-evaluates the model at the predicted point and blends
            # the new difference D1_t into the residual.
            model_t = self.noise_prediction_fn(x_t, t)
            if D1s is not None:
                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = (model_t - model_prev_0)
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)

        return x_t, model_t
336
+
337
    def multistep_uni_pc_bh_sde_update(self, x, model_prev_list, t_prev_list, t, order, level = 0, x_t=None, use_corrector=True):
        """Stochastic (SDE) variant of the UniPC B(h) update.

        Identical to `multistep_uni_pc_bh_update` except that the corrector step
        amplifies the deterministic drift by (1 + level) and injects Gaussian
        noise scaled by ``level``; level=0 reduces to the deterministic update
        plus-or-minus floating point.

        Args:
            x, model_prev_list, t_prev_list, t, order, x_t, use_corrector:
                as in `multistep_uni_pc_bh_update`.
            level: noise mixing level in [0, 1]; 0 means fully deterministic.
        Returns:
            (x_t, model_t) as in the deterministic variant.
        """
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)  # NOTE: unused below

        h = lambda_t - lambda_prev_0
        # SDE noise term sized for a log-SNR step of h (expm1(2h) must be >= 0,
        # i.e. assumes lambda increases along the trajectory -- TODO confirm).
        # NOTE(review): relies on `self.device` being set by the (unseen) __init__.
        z = torch.randn(x.shape, device=self.device)
        z = sigma_t * torch.sqrt(torch.expm1(2.0 * h[0])) * z

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = h[0]
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        # B(h) = h ('bh1'); the expm1 alternative ('bh2') is disabled.
        if True:
            B_h = hh
        else:
            B_h = torch.expm1(hh)

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K, C, H, W)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None

        # Base update with the drift amplified by (1 + level); used by the corrector.
        x_t_ = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * h_phi_1, dims) * (1 + level) * model_prev_0
        )
        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
            else:
                pred_res = 0

            # Predictor uses the *unamplified* base update.
            x_t_p = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            x_t = x_t_p - expand_dims(sigma_t * B_h, dims) * pred_res

        if use_corrector:
            model_t = self.noise_prediction_fn(x_t, t)
            if D1s is not None:
                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = (model_t - model_prev_0)
            # Amplified drift plus level-scaled Gaussian noise.
            x_t = x_t_ - (1 + level) * expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + z * level

        return x_t, model_t
441
+
442
+
443
    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        """Multistep UniPC update, varying-coefficient variant.

        Computes the combination coefficients by explicitly inverting the
        (scaled Vandermonde) C matrix instead of solving the B(h) system.
        NOTE(review): this path is not reachable from `multistep_uni_pc_update`
        in its current form (the dispatcher is hardwired to the B(h) variant).

        Returns:
            (x_t, model_t), with model_t=None when the corrector is disabled.
        """
        # print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()
        # first compute rks: normalized lambda-distances of the history points.
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)  # NOTE: unused below

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix: column k holds rks^k / (k+1)! -like scaling (Taylor weights).
        C = []

        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K, C, H, W)
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p

        if use_corrector:
            # print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

        # Unlike the B(h) variant, hh keeps the full (batched) h tensor.
        hh = h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []  # h^k * phi_k(h) terms, built by the recurrence below
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        model_t = None
        if True:
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            # Base first-order update.
            x_t_ = (
                expand_dims((torch.exp(log_alpha_t - log_alpha_prev_0)),dims) * x
                - expand_dims((sigma_t * h_phi_1),dims) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - expand_dims(sigma_t * h_phi_ks[k + 1],dims) * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.noise_prediction_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                # k = 0 default keeps A_c[k][-1] valid below when K == 1
                # (the loop body never runs in that case).
                k = 0
                for k in range(K - 1):
                    x_t = x_t - expand_dims(sigma_t * h_phi_ks[k + 1],dims) * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - expand_dims(sigma_t * h_phi_ks[K],dims) * (D1_t * A_c[k][-1])
        return x_t, model_t
525
+
526
    def sample(
        self,
        x,
        model_fn,
        order,
        use_corrector,
        lower_order_final,
        start_free_u_step=None,
        free_u_apply_callback=None,
        free_u_stop_callback=None,
        npnet_x = None,
        npnet_scale = None,
        half=False,
        return_intermediate=False,
    ):
        """Run the deterministic multistep UniPC sampling loop over `self.timesteps`.

        Args:
            x: initial noise sample.
            model_fn: callable (x, t) -> noise prediction; wrapped so that a
                scalar t is broadcast over the batch.
            order: UniPC order (warm-up uses increasing orders 1..order-1).
            use_corrector: nominal corrector flag; NOTE(review) it is overwritten
                per-step in the main loop (always True except the last step).
            lower_order_final: reduce the order near the end of sampling.
            start_free_u_step: step index at which to enable FreeU via
                `free_u_apply_callback`; `free_u_stop_callback` disables it first.
            npnet_x: optional precomputed first model output (e.g. from an NPNet);
                used only when npnet_scale is also given. NOTE(review):
                npnet_scale itself is otherwise unused (see commented guidance line).
            half, return_intermediate: currently unused -- TODO confirm intent.
        Returns:
            (x, full_cache): final state and the list of per-step states,
            starting with the first model output.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        steps = self.steps
        vec_t = self.timesteps[0].expand((x.shape[0]))
        if free_u_stop_callback is not None:
            free_u_stop_callback()
        if start_free_u_step is not None and 0 == start_free_u_step and free_u_apply_callback is not None:
            free_u_apply_callback()
            has_called_free_u = True
        if not self.use_afs:
            fir_output = self.noise_prediction_fn(x, vec_t)
        else:
            # Analytic first step: skip the first model call entirely.
            fir_output = x * 0.97 # ultilize npnet there in the future
        if npnet_x is not None and npnet_scale is not None:
            fir_output = npnet_x
            # fir_output = fir_output - npnet_scale * (npnet_out - fir_output) #guidance_scale * (noise - noise_uncond)
        # NOTE(review): the running state is *replaced* by the first model output
        # (and below by each model_x) rather than integrated -- this looks like a
        # noise-refinement scheme; confirm it is intentional.
        x = fir_output.clone().detach().to(fir_output.device)


        model_prev_list = [fir_output]
        full_cache = [fir_output]
        t_prev_list = [vec_t]
        # NOTE(review): this clobbers the True set in the step-0 branch above;
        # harmless only because the loops below compare against indices >= 1.
        has_called_free_u = False
        # Warm-up: take steps of increasing order to build up history.
        for init_order in range(1, order):
            if start_free_u_step is not None and init_order == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
                free_u_apply_callback()
                has_called_free_u = True
            vec_t = self.timesteps[init_order].expand(x.shape[0])
            x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
            if model_x is None:
                model_x = self.noise_prediction_fn(x, vec_t)
            # NOTE(review): state replaced by the model output, cast to fp32.
            x = model_x.clone().detach().to(torch.float32).to(model_x.device)
            full_cache.append(x)
            model_prev_list.append(model_x)
            t_prev_list.append(vec_t)

        # Main loop at full order, rolling the history window.
        for step in range(order, steps + 1):
            if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
                free_u_apply_callback()
            vec_t = self.timesteps[step].expand(x.shape[0])
            if lower_order_final:
                step_order = min(order, steps + 1 - step)
            else:
                step_order = order
            # print('this step order:', step_order)
            if step == steps:
                # print('do not run corrector at the last step')
                use_corrector = False
            else:
                use_corrector = True
            x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
            for i in range(order - 1):
                t_prev_list[i] = t_prev_list[i + 1]
                model_prev_list[i] = model_prev_list[i + 1]
            t_prev_list[-1] = vec_t
            # We do not need to evaluate the final model value.
            full_cache.append(x)
            if step < steps:
                if model_x is None:
                    model_x = self.noise_prediction_fn(x, vec_t)
                model_prev_list[-1] = model_x
        return x, full_cache
602
    def sample_mix(
        self,
        x,
        model_fn,
        order,
        use_corrector,
        lower_order_final,
        start_free_u_step=None,
        free_u_apply_callback=None,
        free_u_stop_callback=None,
        noise_level = 0.1,
        half=False,
        return_intermediate=False,
    ):
        """Run the mixed ODE/SDE multistep UniPC sampling loop.

        Like `sample`, but every update goes through the SDE variant: steps at or
        after `start_free_u_step` (when FreeU is active) use `noise_level` as the
        stochastic mixing level, all other steps use level=0.0 (deterministic).

        Args:
            noise_level: SDE noise level passed as `level` once FreeU kicks in.
            half, return_intermediate: currently unused -- TODO confirm intent.
        Returns:
            (x, full_cache): final state and the list of per-step states.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        steps = self.steps
        vec_t = self.timesteps[0].expand((x.shape[0]))
        fir_output = self.noise_prediction_fn(x, vec_t)
        model_prev_list = [fir_output]
        full_cache = [fir_output]
        t_prev_list = [vec_t]
        has_called_free_u = False
        if free_u_stop_callback is not None:
            free_u_stop_callback()
        # Warm-up with increasing orders to build history.
        for init_order in range(1, order):
            if start_free_u_step is not None and init_order == start_free_u_step and free_u_apply_callback is not None:
                free_u_apply_callback()
                has_called_free_u = True
            vec_t = self.timesteps[init_order].expand(x.shape[0])
            # Stochastic update only once FreeU has been enabled.
            if start_free_u_step is not None and init_order >= start_free_u_step and free_u_apply_callback is not None:
                x, model_x = self.multistep_uni_pc_sde_update(x
                                                              , model_prev_list
                                                              , t_prev_list
                                                              , vec_t
                                                              , init_order
                                                              , use_corrector=True
                                                              ,level=noise_level)
            else:
                x, model_x = self.multistep_uni_pc_sde_update(x
                                                              , model_prev_list
                                                              , t_prev_list
                                                              , vec_t
                                                              , init_order
                                                              , use_corrector=True
                                                              ,level=0.0)
            if model_x is None:
                model_x = self.noise_prediction_fn(x, vec_t)
            # NOTE(review): state replaced by the model output (cast to fp32),
            # same unusual scheme as in `sample` -- confirm intentional.
            x = model_x.clone().detach().to(torch.float32).to(model_x.device)
            full_cache.append(x)
            model_prev_list.append(model_x)
            t_prev_list.append(vec_t)

        # NOTE(review): FreeU is stopped again here even if it was just applied
        # in the warm-up loop -- verify this on/off sequencing is intended.
        if free_u_stop_callback is not None:
            free_u_stop_callback()
        for step in range(order, steps + 1):
            if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
                free_u_apply_callback()
            vec_t = self.timesteps[step].expand(x.shape[0])
            if lower_order_final:
                step_order = min(order, steps + 1 - step)
            else:
                step_order = order
            # print('this step order:', step_order)
            if step == steps:
                # print('do not run corrector at the last step')
                use_corrector = False
            else:
                use_corrector = True
            if start_free_u_step is not None and step >= start_free_u_step and free_u_apply_callback is not None:
                x, model_x = self.multistep_uni_pc_sde_update(x
                                                              , model_prev_list
                                                              , t_prev_list
                                                              , vec_t
                                                              , step_order
                                                              , use_corrector=use_corrector
                                                              , level=noise_level)
            else:
                x, model_x = self.multistep_uni_pc_sde_update(x
                                                              , model_prev_list
                                                              , t_prev_list
                                                              , vec_t
                                                              , step_order
                                                              , use_corrector=use_corrector
                                                              , level=0.0)
            for i in range(order - 1):
                t_prev_list[i] = t_prev_list[i + 1]
                model_prev_list[i] = model_prev_list[i + 1]
            t_prev_list[-1] = vec_t
            # We do not need to evaluate the final model value.
            full_cache.append(x)
            if step < steps:
                if model_x is None:
                    model_x = self.noise_prediction_fn(x, vec_t)
                model_prev_list[-1] = model_x
        return x, full_cache
697
+
698
+
699
+
700
+
701
+ #############################################################
702
+ # other utility functions
703
+ #############################################################
704
+
705
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    Implemented with sorting/gather only, so it stays differentiable for
    autograd. For x outside the range of xp, the outermost segment is
    extended linearly (linear extrapolation).

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    # Insert x among the keypoints and sort, so its rank tells us which
    # segment it falls into.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, perm = torch.sort(merged, dim=2)
    rank = torch.argmin(perm, dim=2)  # position of x within the sorted row
    tentative = rank - 1
    # Clamp the segment index at both ends so out-of-range x extrapolates
    # from the outermost segment.
    seg_lo = torch.where(
        torch.eq(rank, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(rank, num_keys), torch.tensor(num_keys - 2, device=x.device), tentative,
        ),
    )
    seg_hi = torch.where(torch.eq(seg_lo, tentative), seg_lo + 2, seg_lo + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=seg_lo.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=seg_hi.unsqueeze(2)).squeeze(2)
    # Matching index into the keypoint (yp) table.
    key_lo = torch.where(
        torch.eq(rank, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(rank, num_keys), torch.tensor(num_keys - 2, device=x.device), tentative,
        ),
    )
    yp_rows = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_rows, dim=2, index=key_lo.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_rows, dim=2, index=(key_lo + 1).unsqueeze(2)).squeeze(2)
    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
745
+
746
+
747
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dim`: a `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    out = v
    for _ in range(dims - 1):
        out = out.unsqueeze(-1)
    return out