Hanke Chen committed
Commit 57e6edd · 0 Parent(s)

:book: initial commit
.gitignore ADDED
@@ -0,0 +1,4 @@
+ *.pt
+ *.yaml
+ converted
+ __pycache__
scripts/README.md ADDED
@@ -0,0 +1,19 @@
+ # Convert original weights to diffusers
+
+ Download the original MVDream checkpoint and its config file from one of the following sources:
+
+ ```bash
+ # for sd-v1.5 (recommended for production)
+ wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v1.5-4view.pt
+ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v1.yaml
+
+ # for sd-v2.1 (recommended for publication)
+ wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
+ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
+ ```
+
+ Convert the weights to Hugging Face diffusers format with the conversion script:
+ ```bash
+ mkdir converted
+ python ./scripts/convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v1.5-4view.pt --dump_path ./converted --original_config_file ./sd-v1.yaml
+ ```
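After conversion, the dumped folder should load back as a regular diffusers pipeline. The sketch below is illustrative only: it assumes the `MVDreamStableDiffusionPipeline` class from `scripts/pipeline_mvdream.py` (imported by the conversion script below) behaves like a standard `DiffusionPipeline` with `from_pretrained`, and that the weights were dumped to `./converted`.

```python
# Hypothetical loading check; the paths and from_pretrained behaviour are assumptions.
import sys
import torch

sys.path.insert(0, "./scripts")  # so pipeline_mvdream and its sibling modules resolve
from pipeline_mvdream import MVDreamStableDiffusionPipeline

pipe = MVDreamStableDiffusionPipeline.from_pretrained("./converted")
pipe = pipe.to("cuda", torch.float16)  # optional: move to GPU in fp16
```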
scripts/attention.py ADDED
@@ -0,0 +1,412 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+
9
+ from util import checkpoint
10
+
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ XFORMERS_IS_AVAILBLE = True
16
+ except ImportError:
17
+ XFORMERS_IS_AVAILBLE = False
18
+
19
+ # CrossAttn precision handling
20
+ import os
21
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def uniq(arr):
28
+ return {el: True for el in arr}.keys()
29
+
30
+
31
+ def default(val, d):
32
+ if exists(val):
33
+ return val
34
+ return d() if isfunction(d) else d
35
+
36
+
37
+ def max_neg_value(t):
38
+ return -torch.finfo(t.dtype).max
39
+
40
+
41
+ def init_(tensor):
42
+ dim = tensor.shape[-1]
43
+ std = 1 / math.sqrt(dim)
44
+ tensor.uniform_(-std, std)
45
+ return tensor
46
+
47
+
48
+ # feedforward
49
+ class GEGLU(nn.Module):
50
+ def __init__(self, dim_in, dim_out):
51
+ super().__init__()
52
+ self.proj = nn.Linear(dim_in, dim_out * 2)
53
+
54
+ def forward(self, x):
55
+ x, gate = self.proj(x).chunk(2, dim=-1)
56
+ return x * F.gelu(gate)
57
+
58
+
59
+ class FeedForward(nn.Module):
60
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
61
+ super().__init__()
62
+ inner_dim = int(dim * mult)
63
+ dim_out = default(dim_out, dim)
64
+ project_in = nn.Sequential(
65
+ nn.Linear(dim, inner_dim),
66
+ nn.GELU()
67
+ ) if not glu else GEGLU(dim, inner_dim)
68
+
69
+ self.net = nn.Sequential(
70
+ project_in,
71
+ nn.Dropout(dropout),
72
+ nn.Linear(inner_dim, dim_out)
73
+ )
74
+
75
+ def forward(self, x):
76
+ return self.net(x)
77
+
78
+
79
+ def zero_module(module):
80
+ """
81
+ Zero out the parameters of a module and return it.
82
+ """
83
+ for p in module.parameters():
84
+ p.detach().zero_()
85
+ return module
86
+
87
+
88
+ def Normalize(in_channels):
89
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
90
+
91
+
92
+ class SpatialSelfAttention(nn.Module):
93
+ def __init__(self, in_channels):
94
+ super().__init__()
95
+ self.in_channels = in_channels
96
+
97
+ self.norm = Normalize(in_channels)
98
+ self.q = torch.nn.Conv2d(in_channels,
99
+ in_channels,
100
+ kernel_size=1,
101
+ stride=1,
102
+ padding=0)
103
+ self.k = torch.nn.Conv2d(in_channels,
104
+ in_channels,
105
+ kernel_size=1,
106
+ stride=1,
107
+ padding=0)
108
+ self.v = torch.nn.Conv2d(in_channels,
109
+ in_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ padding=0)
113
+ self.proj_out = torch.nn.Conv2d(in_channels,
114
+ in_channels,
115
+ kernel_size=1,
116
+ stride=1,
117
+ padding=0)
118
+
119
+ def forward(self, x):
120
+ h_ = x
121
+ h_ = self.norm(h_)
122
+ q = self.q(h_)
123
+ k = self.k(h_)
124
+ v = self.v(h_)
125
+
126
+ # compute attention
127
+ b,c,h,w = q.shape
128
+ q = rearrange(q, 'b c h w -> b (h w) c')
129
+ k = rearrange(k, 'b c h w -> b c (h w)')
130
+ w_ = torch.einsum('bij,bjk->bik', q, k)
131
+
132
+ w_ = w_ * (int(c)**(-0.5))
133
+ w_ = torch.nn.functional.softmax(w_, dim=2)
134
+
135
+ # attend to values
136
+ v = rearrange(v, 'b c h w -> b c (h w)')
137
+ w_ = rearrange(w_, 'b i j -> b j i')
138
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
139
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
140
+ h_ = self.proj_out(h_)
141
+
142
+ return x+h_
143
+
144
+
145
+ class CrossAttention(nn.Module):
146
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
147
+ super().__init__()
148
+ inner_dim = dim_head * heads
149
+ context_dim = default(context_dim, query_dim)
150
+
151
+ self.scale = dim_head ** -0.5
152
+ self.heads = heads
153
+
154
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
155
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
156
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
157
+
158
+ self.to_out = nn.Sequential(
159
+ nn.Linear(inner_dim, query_dim),
160
+ nn.Dropout(dropout)
161
+ )
162
+
163
+ def forward(self, x, context=None, mask=None):
164
+ h = self.heads
165
+
166
+ q = self.to_q(x)
167
+ context = default(context, x)
168
+ k = self.to_k(context)
169
+ v = self.to_v(context)
170
+
171
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
172
+
173
+ # force cast to fp32 to avoid overflowing
174
+ if _ATTN_PRECISION =="fp32":
175
+ with torch.autocast(enabled=False, device_type = 'cuda'):
176
+ q, k = q.float(), k.float()
177
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178
+ else:
179
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
180
+
181
+ del q, k
182
+
183
+ if exists(mask):
184
+ mask = rearrange(mask, 'b ... -> b (...)')
185
+ max_neg_value = -torch.finfo(sim.dtype).max
186
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
187
+ sim.masked_fill_(~mask, max_neg_value)
188
+
189
+ # attention, what we cannot get enough of
190
+ sim = sim.softmax(dim=-1)
191
+
192
+ out = einsum('b i j, b j d -> b i d', sim, v)
193
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
194
+ return self.to_out(out)
195
+
196
+
197
+ class MemoryEfficientCrossAttention(nn.Module):
198
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
199
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
200
+ super().__init__()
201
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
202
+ f"{heads} heads.")
203
+ inner_dim = dim_head * heads
204
+ context_dim = default(context_dim, query_dim)
205
+
206
+ self.heads = heads
207
+ self.dim_head = dim_head
208
+
209
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
210
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
211
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
212
+
213
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
214
+ self.attention_op: Optional[Any] = None
215
+
216
+ def forward(self, x, context=None, mask=None):
217
+ q = self.to_q(x)
218
+ context = default(context, x)
219
+ k = self.to_k(context)
220
+ v = self.to_v(context)
221
+
222
+ b, _, _ = q.shape
223
+ q, k, v = map(
224
+ lambda t: t.unsqueeze(3)
225
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
226
+ .permute(0, 2, 1, 3)
227
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
228
+ .contiguous(),
229
+ (q, k, v),
230
+ )
231
+
232
+ # actually compute the attention, what we cannot get enough of
233
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
234
+
235
+ if exists(mask):
236
+ raise NotImplementedError
237
+ out = (
238
+ out.unsqueeze(0)
239
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
240
+ .permute(0, 2, 1, 3)
241
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
242
+ )
243
+ return self.to_out(out)
244
+
245
+
246
+ class BasicTransformerBlock(nn.Module):
247
+ ATTENTION_MODES = {
248
+ "softmax": CrossAttention, # vanilla attention
249
+ "softmax-xformers": MemoryEfficientCrossAttention
250
+ }
251
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
252
+ disable_self_attn=False):
253
+ super().__init__()
254
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
255
+ assert attn_mode in self.ATTENTION_MODES
256
+ attn_cls = self.ATTENTION_MODES[attn_mode]
257
+ self.disable_self_attn = disable_self_attn
258
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
259
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
260
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
261
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
262
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+ self.checkpoint = checkpoint
267
+
268
+ def forward(self, x, context=None):
269
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
270
+
271
+ def _forward(self, x, context=None):
272
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
273
+ x = self.attn2(self.norm2(x), context=context) + x
274
+ x = self.ff(self.norm3(x)) + x
275
+ return x
276
+
277
+
278
+ class SpatialTransformer(nn.Module):
279
+ """
280
+ Transformer block for image-like data.
281
+ First, project the input (aka embedding)
282
+ and reshape to b, t, d.
283
+ Then apply standard transformer action.
284
+ Finally, reshape to image
285
+ NEW: use_linear for more efficiency instead of the 1x1 convs
286
+ """
287
+ def __init__(self, in_channels, n_heads, d_head,
288
+ depth=1, dropout=0., context_dim=None,
289
+ disable_self_attn=False, use_linear=False,
290
+ use_checkpoint=True):
291
+ super().__init__()
292
+ if exists(context_dim) and not isinstance(context_dim, list):
293
+ context_dim = [context_dim]
294
+ self.in_channels = in_channels
295
+ inner_dim = n_heads * d_head
296
+ self.norm = Normalize(in_channels)
297
+ if not use_linear:
298
+ self.proj_in = nn.Conv2d(in_channels,
299
+ inner_dim,
300
+ kernel_size=1,
301
+ stride=1,
302
+ padding=0)
303
+ else:
304
+ self.proj_in = nn.Linear(in_channels, inner_dim)
305
+
306
+ self.transformer_blocks = nn.ModuleList(
307
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
308
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
309
+ for d in range(depth)]
310
+ )
311
+ if not use_linear:
312
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
313
+ in_channels,
314
+ kernel_size=1,
315
+ stride=1,
316
+ padding=0))
317
+ else:
318
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
319
+ self.use_linear = use_linear
320
+
321
+ def forward(self, x, context=None):
322
+ # note: if no context is given, cross-attention defaults to self-attention
323
+ if not isinstance(context, list):
324
+ context = [context]
325
+ b, c, h, w = x.shape
326
+ x_in = x
327
+ x = self.norm(x)
328
+ if not self.use_linear:
329
+ x = self.proj_in(x)
330
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
331
+ if self.use_linear:
332
+ x = self.proj_in(x)
333
+ for i, block in enumerate(self.transformer_blocks):
334
+ x = block(x, context=context[i])
335
+ if self.use_linear:
336
+ x = self.proj_out(x)
337
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
338
+ if not self.use_linear:
339
+ x = self.proj_out(x)
340
+ return x + x_in
341
+
342
+
343
+ class BasicTransformerBlock3D(BasicTransformerBlock):
344
+
345
+ def forward(self, x, context=None, num_frames=1):
346
+ return checkpoint(self._forward, (x, context, num_frames), self.parameters(), self.checkpoint)
347
+
348
+ def _forward(self, x, context=None, num_frames=1):
349
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
350
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
351
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
352
+ x = self.attn2(self.norm2(x), context=context) + x
353
+ x = self.ff(self.norm3(x)) + x
354
+ return x
355
+
356
+
357
+ class SpatialTransformer3D(nn.Module):
358
+ ''' 3D self-attention '''
359
+ def __init__(self, in_channels, n_heads, d_head,
360
+ depth=1, dropout=0., context_dim=None,
361
+ disable_self_attn=False, use_linear=False,
362
+ use_checkpoint=True):
363
+ super().__init__()
364
+ if exists(context_dim) and not isinstance(context_dim, list):
365
+ context_dim = [context_dim]
366
+ self.in_channels = in_channels
367
+ inner_dim = n_heads * d_head
368
+ self.norm = Normalize(in_channels)
369
+ if not use_linear:
370
+ self.proj_in = nn.Conv2d(in_channels,
371
+ inner_dim,
372
+ kernel_size=1,
373
+ stride=1,
374
+ padding=0)
375
+ else:
376
+ self.proj_in = nn.Linear(in_channels, inner_dim)
377
+
378
+ self.transformer_blocks = nn.ModuleList(
379
+ [BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
380
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
381
+ for d in range(depth)]
382
+ )
383
+ if not use_linear:
384
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
385
+ in_channels,
386
+ kernel_size=1,
387
+ stride=1,
388
+ padding=0))
389
+ else:
390
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
391
+ self.use_linear = use_linear
392
+
393
+ def forward(self, x, context=None, num_frames=1):
394
+ # note: if no context is given, cross-attention defaults to self-attention
395
+ if not isinstance(context, list):
396
+ context = [context]
397
+ b, c, h, w = x.shape
398
+ x_in = x
399
+ x = self.norm(x)
400
+ if not self.use_linear:
401
+ x = self.proj_in(x)
402
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
403
+ if self.use_linear:
404
+ x = self.proj_in(x)
405
+ for i, block in enumerate(self.transformer_blocks):
406
+ x = block(x, context=context[i], num_frames=num_frames)
407
+ if self.use_linear:
408
+ x = self.proj_out(x)
409
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
410
+ if not self.use_linear:
411
+ x = self.proj_out(x)
412
+ return x + x_in
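The 3D variant above folds the view/frame axis into the token axis for self-attention (`attn1`) and unfolds it again before the per-view cross-attention (`attn2`). The snippet below is a rough shape check, not part of the commit; it assumes `scripts/` is on the import path, that the repo's `util.py` (not shown in this excerpt) provides the usual `checkpoint` helper, and that xformers is unavailable so the plain `CrossAttention` path is taken. Sizes are arbitrary.

```python
# Illustrative shape check for SpatialTransformer3D; sizes are assumptions.
import torch
from attention import SpatialTransformer3D

num_views, channels = 4, 320                     # e.g. a 4-view batch of 320-channel features
x = torch.randn(num_views, channels, 16, 16)     # (b*f, c, h, w) with b=1, f=4
context = torch.randn(num_views, 77, 1024)       # per-view text embeddings

block = SpatialTransformer3D(in_channels=channels, n_heads=8, d_head=40,
                             depth=1, context_dim=1024, use_linear=True,
                             use_checkpoint=False)
out = block(x, context=context, num_frames=num_views)
print(out.shape)  # torch.Size([4, 320, 16, 16]) -- residual output keeps the input shape
```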
scripts/convert_mvdream_to_diffusers.py ADDED
@@ -0,0 +1,794 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/bc691231360a4cbc7d19a58742ebb8ed0f05e027/scripts/convert_original_stable_diffusion_to_diffusers.py
2
+
3
+ import argparse
4
+ import torch
5
+ import sys
6
+ sys.path.insert(0, '../')
7
+
8
+ from transformers import (
9
+ CLIPImageProcessor,
10
+ CLIPVisionModelWithProjection,
11
+ )
12
+
13
+ from diffusers.models import (
14
+ AutoencoderKL,
15
+ UNet2DConditionModel,
16
+ )
17
+ from diffusers.schedulers import DDIMScheduler
18
+ from diffusers.utils import logging
19
+
20
+ from accelerate import init_empty_weights
21
+ from accelerate.utils import set_module_tensor_to_device
22
+ from rich import print, print_json
23
+ from models import MultiViewUNetModel
24
+ from pipeline_mvdream import MVDreamStableDiffusionPipeline
25
+ from transformers import CLIPTokenizer, CLIPTextModel
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ # def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
30
+ # """
31
+ # Creates a config for the diffusers based on the config of the LDM model.
32
+ # """
33
+ # if controlnet:
34
+ # unet_params = original_config.model.params.control_stage_config.params
35
+ # else:
36
+ # if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
37
+ # unet_params = original_config.model.params.unet_config.params
38
+ # else:
39
+ # unet_params = original_config.model.params.network_config.params
40
+
41
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
42
+
43
+ # block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
44
+
45
+ # down_block_types = []
46
+ # resolution = 1
47
+ # for i in range(len(block_out_channels)):
48
+ # block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
49
+ # down_block_types.append(block_type)
50
+ # if i != len(block_out_channels) - 1:
51
+ # resolution *= 2
52
+
53
+ # up_block_types = []
54
+ # for i in range(len(block_out_channels)):
55
+ # block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
56
+ # up_block_types.append(block_type)
57
+ # resolution //= 2
58
+
59
+ # if unet_params.transformer_depth is not None:
60
+ # transformer_layers_per_block = (
61
+ # unet_params.transformer_depth
62
+ # if isinstance(unet_params.transformer_depth, int)
63
+ # else list(unet_params.transformer_depth)
64
+ # )
65
+ # else:
66
+ # transformer_layers_per_block = 1
67
+
68
+ # vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
69
+
70
+ # head_dim = unet_params.num_heads if "num_heads" in unet_params else None
71
+ # use_linear_projection = (
72
+ # unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
73
+ # )
74
+ # if use_linear_projection:
75
+ # # stable diffusion 2-base-512 and 2-768
76
+ # if head_dim is None:
77
+ # head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
78
+ # head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
79
+
80
+ # class_embed_type = None
81
+ # addition_embed_type = None
82
+ # addition_time_embed_dim = None
83
+ # projection_class_embeddings_input_dim = None
84
+ # context_dim = None
85
+
86
+ # if unet_params.context_dim is not None:
87
+ # context_dim = (
88
+ # unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
89
+ # )
90
+
91
+ # if "num_classes" in unet_params:
92
+ # if unet_params.num_classes == "sequential":
93
+ # if context_dim in [2048, 1280]:
94
+ # # SDXL
95
+ # addition_embed_type = "text_time"
96
+ # addition_time_embed_dim = 256
97
+ # else:
98
+ # class_embed_type = "projection"
99
+ # assert "adm_in_channels" in unet_params
100
+ # projection_class_embeddings_input_dim = unet_params.adm_in_channels
101
+ # else:
102
+ # raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
103
+
104
+ # config = {
105
+ # "sample_size": image_size // vae_scale_factor,
106
+ # "in_channels": unet_params.in_channels,
107
+ # "down_block_types": tuple(down_block_types),
108
+ # "block_out_channels": tuple(block_out_channels),
109
+ # "layers_per_block": unet_params.num_res_blocks,
110
+ # "cross_attention_dim": context_dim,
111
+ # "attention_head_dim": head_dim,
112
+ # "use_linear_projection": use_linear_projection,
113
+ # "class_embed_type": class_embed_type,
114
+ # "addition_embed_type": addition_embed_type,
115
+ # "addition_time_embed_dim": addition_time_embed_dim,
116
+ # "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
117
+ # "transformer_layers_per_block": transformer_layers_per_block,
118
+ # }
119
+
120
+ # if controlnet:
121
+ # config["conditioning_channels"] = unet_params.hint_channels
122
+ # else:
123
+ # config["out_channels"] = unet_params.out_channels
124
+ # config["up_block_types"] = tuple(up_block_types)
125
+
126
+ # return config
127
+
128
+
129
+ def assign_to_checkpoint(
130
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
131
+ ):
132
+ """
133
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
134
+ attention layers, and takes into account additional replacements that may arise.
135
+ Assigns the weights to the new checkpoint.
136
+ """
137
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
138
+
139
+ # Splits the attention layers into three variables.
140
+ if attention_paths_to_split is not None:
141
+ for path, path_map in attention_paths_to_split.items():
142
+ old_tensor = old_checkpoint[path]
143
+ channels = old_tensor.shape[0] // 3
144
+
145
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
146
+
147
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
148
+
149
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
150
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
151
+
152
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
153
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
154
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
155
+
156
+ for path in paths:
157
+ new_path = path["new"]
158
+
159
+ # These have already been assigned
160
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
161
+ continue
162
+
163
+ # Global renaming happens here
164
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
165
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
166
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
167
+
168
+ if additional_replacements is not None:
169
+ for replacement in additional_replacements:
170
+ new_path = new_path.replace(replacement["old"], replacement["new"])
171
+
172
+ # proj_attn.weight has to be converted from conv 1D to linear
173
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
174
+ shape = old_checkpoint[path["old"]].shape
175
+ if is_attn_weight and len(shape) == 3:
176
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
177
+ elif is_attn_weight and len(shape) == 4:
178
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
179
+ else:
180
+ checkpoint[new_path] = old_checkpoint[path["old"]]
181
+
182
+
183
+ def shave_segments(path, n_shave_prefix_segments=1):
184
+ """
185
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
186
+ """
187
+ if n_shave_prefix_segments >= 0:
188
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
189
+ else:
190
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
191
+
192
+
193
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
194
+ """
195
+ Updates paths inside resnets to the new naming scheme (local renaming)
196
+ """
197
+ mapping = []
198
+ for old_item in old_list:
199
+ new_item = old_item.replace("in_layers.0", "norm1")
200
+ new_item = new_item.replace("in_layers.2", "conv1")
201
+
202
+ new_item = new_item.replace("out_layers.0", "norm2")
203
+ new_item = new_item.replace("out_layers.3", "conv2")
204
+
205
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
206
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
207
+
208
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
209
+
210
+ mapping.append({"old": old_item, "new": new_item})
211
+
212
+ return mapping
213
+
214
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
215
+ """
216
+ Updates paths inside attentions to the new naming scheme (local renaming)
217
+ """
218
+ mapping = []
219
+ for old_item in old_list:
220
+ new_item = old_item
221
+
222
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
223
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
224
+
225
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
226
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
227
+
228
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
229
+
230
+ mapping.append({"old": old_item, "new": new_item})
231
+
232
+ return mapping
233
+
234
+ # def convert_ldm_unet_checkpoint(
235
+ # checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
236
+ # ):
237
+ # """
238
+ # Takes a state dict and a config, and returns a converted checkpoint.
239
+ # """
240
+
241
+ # if skip_extract_state_dict:
242
+ # unet_state_dict = checkpoint
243
+ # else:
244
+ # # extract state_dict for UNet
245
+ # unet_state_dict = {}
246
+ # keys = list(checkpoint.keys())
247
+
248
+ # if controlnet:
249
+ # unet_key = "control_model."
250
+ # else:
251
+ # unet_key = "model.diffusion_model."
252
+
253
+ # # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
254
+ # if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
255
+ # logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
256
+ # logger.warning(
257
+ # "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
258
+ # " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
259
+ # )
260
+ # for key in keys:
261
+ # if key.startswith("model.diffusion_model"):
262
+ # flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
263
+ # unet_state_dict[key.replace(unet_key, "")] = checkpoint[flat_ema_key]
264
+ # else:
265
+ # if sum(k.startswith("model_ema") for k in keys) > 100:
266
+ # logger.warning(
267
+ # "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
268
+ # " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
269
+ # )
270
+
271
+ # for key in keys:
272
+ # if key.startswith(unet_key):
273
+ # unet_state_dict[key.replace(unet_key, "")] = checkpoint[key]
274
+
275
+ # new_checkpoint = {}
276
+
277
+ # new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
278
+ # new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
279
+ # new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
280
+ # new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
281
+
282
+ # if config["class_embed_type"] is None:
283
+ # # No parameters to port
284
+ # ...
285
+ # elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
286
+ # new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
287
+ # new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
288
+ # new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
289
+ # new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
290
+ # else:
291
+ # raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
292
+
293
+ # if config["addition_embed_type"] == "text_time":
294
+ # new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
295
+ # new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
296
+ # new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
297
+ # new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
298
+
299
+ # new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
300
+ # new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
301
+
302
+ # if not controlnet:
303
+ # new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
304
+ # new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
305
+ # new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
306
+ # new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
307
+
308
+ # # Retrieves the keys for the input blocks only
309
+ # num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
310
+ # input_blocks = {
311
+ # layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
312
+ # for layer_id in range(num_input_blocks)
313
+ # }
314
+
315
+ # # Retrieves the keys for the middle blocks only
316
+ # num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
317
+ # middle_blocks = {
318
+ # layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
319
+ # for layer_id in range(num_middle_blocks)
320
+ # }
321
+
322
+ # # Retrieves the keys for the output blocks only
323
+ # num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
324
+ # output_blocks = {
325
+ # layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
326
+ # for layer_id in range(num_output_blocks)
327
+ # }
328
+
329
+ # for i in range(1, num_input_blocks):
330
+ # block_id = (i - 1) // (config["layers_per_block"] + 1)
331
+ # layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
332
+
333
+ # resnets = [
334
+ # key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
335
+ # ]
336
+ # attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
337
+
338
+ # if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
339
+ # new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
340
+ # f"input_blocks.{i}.0.op.weight"
341
+ # )
342
+ # new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
343
+ # f"input_blocks.{i}.0.op.bias"
344
+ # )
345
+
346
+ # paths = renew_resnet_paths(resnets)
347
+ # meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
348
+ # assign_to_checkpoint(
349
+ # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
350
+ # )
351
+
352
+ # if len(attentions):
353
+ # paths = renew_attention_paths(attentions)
354
+ # meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
355
+ # assign_to_checkpoint(
356
+ # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
357
+ # )
358
+
359
+ # resnet_0 = middle_blocks[0]
360
+ # attentions = middle_blocks[1]
361
+ # resnet_1 = middle_blocks[2]
362
+
363
+ # resnet_0_paths = renew_resnet_paths(resnet_0)
364
+ # assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
365
+
366
+ # resnet_1_paths = renew_resnet_paths(resnet_1)
367
+ # assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
368
+
369
+ # attentions_paths = renew_attention_paths(attentions)
370
+ # meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
371
+ # assign_to_checkpoint(
372
+ # attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
373
+ # )
374
+
375
+ # for i in range(num_output_blocks):
376
+ # block_id = i // (config["layers_per_block"] + 1)
377
+ # layer_in_block_id = i % (config["layers_per_block"] + 1)
378
+ # output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
379
+ # output_block_list = {}
380
+
381
+ # for layer in output_block_layers:
382
+ # layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
383
+ # if layer_id in output_block_list:
384
+ # output_block_list[layer_id].append(layer_name)
385
+ # else:
386
+ # output_block_list[layer_id] = [layer_name]
387
+
388
+ # if len(output_block_list) > 1:
389
+ # resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
390
+ # attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
391
+
392
+ # resnet_0_paths = renew_resnet_paths(resnets)
393
+ # paths = renew_resnet_paths(resnets)
394
+
395
+ # meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
396
+ # assign_to_checkpoint(
397
+ # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
398
+ # )
399
+
400
+ # output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
401
+ # if ["conv.bias", "conv.weight"] in output_block_list.values():
402
+ # index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
403
+ # new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
404
+ # f"output_blocks.{i}.{index}.conv.weight"
405
+ # ]
406
+ # new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
407
+ # f"output_blocks.{i}.{index}.conv.bias"
408
+ # ]
409
+
410
+ # # Clear attentions as they have been attributed above.
411
+ # if len(attentions) == 2:
412
+ # attentions = []
413
+
414
+ # if len(attentions):
415
+ # paths = renew_attention_paths(attentions)
416
+ # meta_path = {
417
+ # "old": f"output_blocks.{i}.1",
418
+ # "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
419
+ # }
420
+ # assign_to_checkpoint(
421
+ # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
422
+ # )
423
+ # else:
424
+ # resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
425
+ # for path in resnet_0_paths:
426
+ # old_path = ".".join(["output_blocks", str(i), path["old"]])
427
+ # new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
428
+
429
+ # new_checkpoint[new_path] = unet_state_dict[old_path]
430
+
431
+ # if controlnet:
432
+ # # conditioning embedding
433
+
434
+ # orig_index = 0
435
+
436
+ # new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
437
+ # f"input_hint_block.{orig_index}.weight"
438
+ # )
439
+ # new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
440
+ # f"input_hint_block.{orig_index}.bias"
441
+ # )
442
+
443
+ # orig_index += 2
444
+
445
+ # diffusers_index = 0
446
+
447
+ # while diffusers_index < 6:
448
+ # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
449
+ # f"input_hint_block.{orig_index}.weight"
450
+ # )
451
+ # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
452
+ # f"input_hint_block.{orig_index}.bias"
453
+ # )
454
+ # diffusers_index += 1
455
+ # orig_index += 2
456
+
457
+ # new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
458
+ # f"input_hint_block.{orig_index}.weight"
459
+ # )
460
+ # new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
461
+ # f"input_hint_block.{orig_index}.bias"
462
+ # )
463
+
464
+ # # down blocks
465
+ # for i in range(num_input_blocks):
466
+ # new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
467
+ # new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
468
+
469
+ # # mid block
470
+ # new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
471
+ # new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
472
+
473
+ # return new_checkpoint
474
+
475
+
476
+ def create_vae_diffusers_config(original_config, image_size: int):
477
+ """
478
+ Creates a config for the diffusers based on the config of the LDM model.
479
+ """
480
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
481
+ _ = original_config.model.params.first_stage_config.params.embed_dim
482
+
483
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
484
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
485
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
486
+
487
+ config = {
488
+ "sample_size": image_size,
489
+ "in_channels": vae_params.in_channels,
490
+ "out_channels": vae_params.out_ch,
491
+ "down_block_types": tuple(down_block_types),
492
+ "up_block_types": tuple(up_block_types),
493
+ "block_out_channels": tuple(block_out_channels),
494
+ "latent_channels": vae_params.z_channels,
495
+ "layers_per_block": vae_params.num_res_blocks,
496
+ }
497
+ return config
498
+
499
+ def convert_ldm_vae_checkpoint(checkpoint, config):
500
+ # extract state dict for VAE
501
+ vae_state_dict = {}
502
+ vae_key = "first_stage_model."
503
+ keys = list(checkpoint.keys())
504
+ for key in keys:
505
+ if key.startswith(vae_key):
506
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
507
+
508
+ new_checkpoint = {}
509
+
510
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
511
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
512
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
513
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
514
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
515
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
516
+
517
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
518
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
519
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
520
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
521
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
522
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
523
+
524
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
525
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
526
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
527
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
528
+
529
+ # Retrieves the keys for the encoder down blocks only
530
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
531
+ down_blocks = {
532
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
533
+ }
534
+
535
+ # Retrieves the keys for the decoder up blocks only
536
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
537
+ up_blocks = {
538
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
539
+ }
540
+
541
+ for i in range(num_down_blocks):
542
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
543
+
544
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
545
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
546
+ f"encoder.down.{i}.downsample.conv.weight"
547
+ )
548
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
549
+ f"encoder.down.{i}.downsample.conv.bias"
550
+ )
551
+
552
+ paths = renew_vae_resnet_paths(resnets)
553
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
554
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
555
+
556
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
557
+ num_mid_res_blocks = 2
558
+ for i in range(1, num_mid_res_blocks + 1):
559
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
560
+
561
+ paths = renew_vae_resnet_paths(resnets)
562
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
563
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
564
+
565
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
566
+ paths = renew_vae_attention_paths(mid_attentions)
567
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
568
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
569
+ conv_attn_to_linear(new_checkpoint)
570
+
571
+ for i in range(num_up_blocks):
572
+ block_id = num_up_blocks - 1 - i
573
+ resnets = [
574
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
575
+ ]
576
+
577
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
578
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
579
+ f"decoder.up.{block_id}.upsample.conv.weight"
580
+ ]
581
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
582
+ f"decoder.up.{block_id}.upsample.conv.bias"
583
+ ]
584
+
585
+ paths = renew_vae_resnet_paths(resnets)
586
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
587
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
588
+
589
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
590
+ num_mid_res_blocks = 2
591
+ for i in range(1, num_mid_res_blocks + 1):
592
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
593
+
594
+ paths = renew_vae_resnet_paths(resnets)
595
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
596
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
597
+
598
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
599
+ paths = renew_vae_attention_paths(mid_attentions)
600
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
601
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
602
+ conv_attn_to_linear(new_checkpoint)
603
+ return new_checkpoint
604
+
605
+
606
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
607
+ """
608
+ Updates paths inside resnets to the new naming scheme (local renaming)
609
+ """
610
+ mapping = []
611
+ for old_item in old_list:
612
+ new_item = old_item
613
+
614
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
615
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
616
+
617
+ mapping.append({"old": old_item, "new": new_item})
618
+
619
+ return mapping
620
+
621
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
622
+ """
623
+ Updates paths inside attentions to the new naming scheme (local renaming)
624
+ """
625
+ mapping = []
626
+ for old_item in old_list:
627
+ new_item = old_item
628
+
629
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
630
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
631
+
632
+ new_item = new_item.replace("q.weight", "to_q.weight")
633
+ new_item = new_item.replace("q.bias", "to_q.bias")
634
+
635
+ new_item = new_item.replace("k.weight", "to_k.weight")
636
+ new_item = new_item.replace("k.bias", "to_k.bias")
637
+
638
+ new_item = new_item.replace("v.weight", "to_v.weight")
639
+ new_item = new_item.replace("v.bias", "to_v.bias")
640
+
641
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
642
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
643
+
644
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
645
+
646
+ mapping.append({"old": old_item, "new": new_item})
647
+
648
+ return mapping
649
+
650
+
651
+ def conv_attn_to_linear(checkpoint):
652
+ keys = list(checkpoint.keys())
653
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
654
+ for key in keys:
655
+ if ".".join(key.split(".")[-2:]) in attn_keys:
656
+ if checkpoint[key].ndim > 2:
657
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
658
+ elif "proj_attn.weight" in key:
659
+ if checkpoint[key].ndim > 2:
660
+ checkpoint[key] = checkpoint[key][:, :, 0]
661
+
662
+
663
+ def convert_from_original_mvdream_ckpt(
664
+ checkpoint_path,
665
+ original_config_file,
666
+ extract_ema,
667
+ device
668
+ ):
669
+ checkpoint = torch.load(checkpoint_path, map_location=device)
670
+ print(f"Checkpoint: {checkpoint.keys()}")
671
+ torch.cuda.empty_cache()
672
+
673
+ from omegaconf import OmegaConf
674
+
675
+ original_config = OmegaConf.load(original_config_file)
676
+ print(f"Original Config: {original_config}")
677
+ prediction_type = "epsilon"
678
+ image_size = 256
679
+ num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
680
+ beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
681
+ beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
682
+ scheduler = DDIMScheduler(
683
+ beta_end=beta_end,
684
+ beta_schedule="scaled_linear",
685
+ beta_start=beta_start,
686
+ num_train_timesteps=num_train_timesteps,
687
+ steps_offset=1,
688
+ clip_sample=False,
689
+ set_alpha_to_one=False,
690
+ prediction_type=prediction_type,
691
+ )
692
+ scheduler.register_to_config(clip_sample=False)
693
+
694
+ # Convert the UNet2DConditionModel model.
695
+ # upcast_attention = None
696
+ # unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
697
+ # unet_config["upcast_attention"] = upcast_attention
698
+ # with init_empty_weights():
699
+ # unet = UNet2DConditionModel(**unet_config)
700
+ # converted_unet_checkpoint = convert_ldm_unet_checkpoint(
701
+ # checkpoint, unet_config, path=None, extract_ema=extract_ema
702
+ # )
703
+ print(f"Unet Config: {original_config.model.params.unet_config.params}")
704
+ unet: MultiViewUNetModel = MultiViewUNetModel(**original_config.model.params.unet_config.params)
705
+ unet.load_state_dict({
706
+ key.replace("model.diffusion_model.", ""): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "") in unet.state_dict()
707
+ })
708
+ for param_name, param in unet.state_dict().items():
709
+ set_module_tensor_to_device(unet, param_name, "cuda:0", value=param)
710
+
711
+ # Convert the VAE model.
712
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
713
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
714
+
715
+ if (
716
+ "model" in original_config
717
+ and "params" in original_config.model
718
+ and "scale_factor" in original_config.model.params
719
+ ):
720
+ vae_scaling_factor = original_config.model.params.scale_factor
721
+ else:
722
+ vae_scaling_factor = 0.18215 # default SD scaling factor
723
+
724
+ vae_config["scaling_factor"] = vae_scaling_factor
725
+
726
+ with init_empty_weights():
727
+ vae = AutoencoderKL(**vae_config)
728
+
729
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
730
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=torch.device("cuda:0")) # type: ignore
731
+
732
+ for param_name, param in converted_vae_checkpoint.items():
733
+ set_module_tensor_to_device(vae, param_name, "cuda:0", value=param)
734
+
735
+ pipe = MVDreamStableDiffusionPipeline(
736
+ vae=vae,
737
+ unet=unet,
738
+ tokenizer=tokenizer,
739
+ text_encoder=text_encoder,
740
+ scheduler=scheduler,
741
+ safety_checker=None,
742
+ feature_extractor=None,
743
+ requires_safety_checker=False
744
+ )
745
+
746
+ return pipe
747
+
748
+
749
+ if __name__ == "__main__":
750
+ parser = argparse.ArgumentParser()
751
+
752
+ parser.add_argument(
753
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
754
+ )
755
+ parser.add_argument(
756
+ "--original_config_file",
757
+ default=None,
758
+ type=str,
759
+ help="The YAML config file corresponding to the original architecture.",
760
+ )
761
+ parser.add_argument(
762
+ "--extract_ema",
763
+ action="store_true",
764
+ help=(
765
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
766
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
767
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
768
+ ),
769
+ )
770
+ parser.add_argument(
771
+ "--to_safetensors",
772
+ action="store_true",
773
+ help="Whether to store pipeline in safetensors format or not.",
774
+ )
775
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
776
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
777
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
778
+ args = parser.parse_args()
779
+
780
+ pipe = convert_from_original_mvdream_ckpt(
781
+ checkpoint_path=args.checkpoint_path,
782
+ original_config_file=args.original_config_file,
783
+ extract_ema=args.extract_ema,
784
+ device=args.device,
785
+ )
786
+
787
+ if args.half:
788
+ pipe.to(torch_dtype=torch.float16)
789
+
794
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
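The `__main__` block above wires the converter to argparse; the same conversion can also be driven programmatically. The snippet below is a hypothetical sketch, not part of the commit: it assumes it is run from the `scripts/` directory (so the converter's relative imports resolve), that a CUDA device is present (the function moves weights to `cuda:0` internally), and it uses placeholder paths.

```python
# Hypothetical programmatic conversion; paths are placeholders.
import torch
from convert_mvdream_to_diffusers import convert_from_original_mvdream_ckpt

pipe = convert_from_original_mvdream_ckpt(
    checkpoint_path="./sd-v2.1-base-4view.pt",
    original_config_file="./sd-v2-base.yaml",
    extract_ema=False,
    device="cuda:0",
)
pipe.to(torch_dtype=torch.float16)                            # equivalent to --half
pipe.save_pretrained("./converted", safe_serialization=True)  # equivalent to --to_safetensors
```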
scripts/models.py ADDED
@@ -0,0 +1,1214 @@
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from attention import SpatialTransformer, SpatialTransformer3D, exists
19
+
20
+
21
+ # dummy replace
22
+ def convert_module_to_f16(x):
23
+ pass
24
+
25
+ def convert_module_to_f32(x):
26
+ pass
27
+
28
+
29
+ ## go
30
+ class AttentionPool2d(nn.Module):
31
+ """
32
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ spacial_dim: int,
38
+ embed_dim: int,
39
+ num_heads_channels: int,
40
+ output_dim: int = None,
41
+ ):
42
+ super().__init__()
43
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
44
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
45
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
46
+ self.num_heads = embed_dim // num_heads_channels
47
+ self.attention = QKVAttention(self.num_heads)
48
+
49
+ def forward(self, x):
50
+ b, c, *_spatial = x.shape
51
+ x = x.reshape(b, c, -1) # NC(HW)
52
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
53
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
54
+ x = self.qkv_proj(x)
55
+ x = self.attention(x)
56
+ x = self.c_proj(x)
57
+ return x[:, :, 0]
58
+
59
+
60
+ class TimestepBlock(nn.Module):
61
+ """
62
+ Any module where forward() takes timestep embeddings as a second argument.
63
+ """
64
+
65
+ @abstractmethod
66
+ def forward(self, x, emb):
67
+ """
68
+ Apply the module to `x` given `emb` timestep embeddings.
69
+ """
70
+
71
+
72
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
73
+ """
74
+ A sequential module that passes timestep embeddings to the children that
75
+ support it as an extra input.
76
+ """
77
+
78
+ def forward(self, x, emb, context=None, num_frames=1):
79
+ for layer in self:
80
+ if isinstance(layer, TimestepBlock):
81
+ x = layer(x, emb)
82
+ elif isinstance(layer, SpatialTransformer3D):
83
+ x = layer(x, context, num_frames=num_frames)
84
+ elif isinstance(layer, SpatialTransformer):
85
+ x = layer(x, context)
86
+ else:
87
+ x = layer(x)
88
+ return x
89
+
90
+
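As a quick illustration of the dispatch above: `TimestepEmbedSequential` forwards only the arguments each child understands. The sketch below uses a minimal, made-up configuration (channel counts chosen so the 32-group normalization divides evenly); it is not taken from any released config.

```python
import torch as th

block = TimestepEmbedSequential(
    ResBlock(32, 128, 0.0),                                    # receives (x, emb)
    SpatialTransformer3D(32, 4, 8, depth=1, context_dim=16),   # receives (x, context, num_frames)
)
x = th.randn(4, 32, 16, 16)     # one object seen from four views, stacked along the batch axis
emb = th.randn(4, 128)          # timestep embeddings
ctx = th.randn(4, 77, 16)       # per-view text conditioning
out = block(x, emb, ctx, num_frames=4)
assert out.shape == x.shape
```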
91
+ class Upsample(nn.Module):
92
+ """
93
+ An upsampling layer with an optional convolution.
94
+ :param channels: channels in the inputs and outputs.
95
+ :param use_conv: a bool determining if a convolution is applied.
96
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
+ upsampling occurs in the inner-two dimensions.
98
+ """
99
+
100
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
+ super().__init__()
102
+ self.channels = channels
103
+ self.out_channels = out_channels or channels
104
+ self.use_conv = use_conv
105
+ self.dims = dims
106
+ if use_conv:
107
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
+
109
+ def forward(self, x):
110
+ assert x.shape[1] == self.channels
111
+ if self.dims == 3:
112
+ x = F.interpolate(
113
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
+ )
115
+ else:
116
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
117
+ if self.use_conv:
118
+ x = self.conv(x)
119
+ return x
120
+
121
+ class TransposedUpsample(nn.Module):
122
+ """Learned 2x upsampling without padding."""
123
+ def __init__(self, channels, out_channels=None, ks=5):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+
128
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
+
130
+ def forward(self,x):
131
+ return self.up(x)
132
+
133
+
134
+ class Downsample(nn.Module):
135
+ """
136
+ A downsampling layer with an optional convolution.
137
+ :param channels: channels in the inputs and outputs.
138
+ :param use_conv: a bool determining if a convolution is applied.
139
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
+ downsampling occurs in the inner-two dimensions.
141
+ """
142
+
143
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.dims = dims
149
+ stride = 2 if dims != 3 else (1, 2, 2)
150
+ if use_conv:
151
+ self.op = conv_nd(
152
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
+ )
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
+
158
+ def forward(self, x):
159
+ assert x.shape[1] == self.channels
160
+ return self.op(x)
161
+
162
+
163
+ class ResBlock(TimestepBlock):
164
+ """
165
+ A residual block that can optionally change the number of channels.
166
+ :param channels: the number of input channels.
167
+ :param emb_channels: the number of timestep embedding channels.
168
+ :param dropout: the rate of dropout.
169
+ :param out_channels: if specified, the number of out channels.
170
+ :param use_conv: if True and out_channels is specified, use a spatial
171
+ convolution instead of a smaller 1x1 convolution to change the
172
+ channels in the skip connection.
173
+ :param dims: determines if the signal is 1D, 2D, or 3D.
174
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
175
+ :param up: if True, use this block for upsampling.
176
+ :param down: if True, use this block for downsampling.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ emb_channels,
183
+ dropout,
184
+ out_channels=None,
185
+ use_conv=False,
186
+ use_scale_shift_norm=False,
187
+ dims=2,
188
+ use_checkpoint=False,
189
+ up=False,
190
+ down=False,
191
+ ):
192
+ super().__init__()
193
+ self.channels = channels
194
+ self.emb_channels = emb_channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_checkpoint = use_checkpoint
199
+ self.use_scale_shift_norm = use_scale_shift_norm
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False, dims)
211
+ self.x_upd = Upsample(channels, False, dims)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False, dims)
214
+ self.x_upd = Downsample(channels, False, dims)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.emb_layers = nn.Sequential(
219
+ nn.SiLU(),
220
+ linear(
221
+ emb_channels,
222
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
+ ),
224
+ )
225
+ self.out_layers = nn.Sequential(
226
+ normalization(self.out_channels),
227
+ nn.SiLU(),
228
+ nn.Dropout(p=dropout),
229
+ zero_module(
230
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
+ ),
232
+ )
233
+
234
+ if self.out_channels == channels:
235
+ self.skip_connection = nn.Identity()
236
+ elif use_conv:
237
+ self.skip_connection = conv_nd(
238
+ dims, channels, self.out_channels, 3, padding=1
239
+ )
240
+ else:
241
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
+
243
+ def forward(self, x, emb):
244
+ """
245
+ Apply the block to a Tensor, conditioned on a timestep embedding.
246
+ :param x: an [N x C x ...] Tensor of features.
247
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
+ :return: an [N x C x ...] Tensor of outputs.
249
+ """
250
+ return checkpoint(
251
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
+ )
253
+
254
+
255
+ def _forward(self, x, emb):
256
+ if self.updown:
257
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
+ h = in_rest(x)
259
+ h = self.h_upd(h)
260
+ x = self.x_upd(x)
261
+ h = in_conv(h)
262
+ else:
263
+ h = self.in_layers(x)
264
+ emb_out = self.emb_layers(emb).type(h.dtype)
265
+ while len(emb_out.shape) < len(h.shape):
266
+ emb_out = emb_out[..., None]
267
+ if self.use_scale_shift_norm:
268
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
+ scale, shift = th.chunk(emb_out, 2, dim=1)
270
+ h = out_norm(h) * (1 + scale) + shift
271
+ h = out_rest(h)
272
+ else:
273
+ h = h + emb_out
274
+ h = self.out_layers(h)
275
+ return self.skip_connection(x) + h
276
+
277
+
278
+ class AttentionBlock(nn.Module):
279
+ """
280
+ An attention block that allows spatial positions to attend to each other.
281
+ Originally ported from here, but adapted to the N-d case.
282
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ channels,
288
+ num_heads=1,
289
+ num_head_channels=-1,
290
+ use_checkpoint=False,
291
+ use_new_attention_order=False,
292
+ ):
293
+ super().__init__()
294
+ self.channels = channels
295
+ if num_head_channels == -1:
296
+ self.num_heads = num_heads
297
+ else:
298
+ assert (
299
+ channels % num_head_channels == 0
300
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
+ self.num_heads = channels // num_head_channels
302
+ self.use_checkpoint = use_checkpoint
303
+ self.norm = normalization(channels)
304
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
305
+ if use_new_attention_order:
306
+ # split qkv before split heads
307
+ self.attention = QKVAttention(self.num_heads)
308
+ else:
309
+ # split heads before split qkv
310
+ self.attention = QKVAttentionLegacy(self.num_heads)
311
+
312
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
+
314
+ def forward(self, x):
315
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
+ #return pt_checkpoint(self._forward, x) # pytorch
317
+
318
+ def _forward(self, x):
319
+ b, c, *spatial = x.shape
320
+ x = x.reshape(b, c, -1)
321
+ qkv = self.qkv(self.norm(x))
322
+ h = self.attention(qkv)
323
+ h = self.proj_out(h)
324
+ return (x + h).reshape(b, c, *spatial)
325
+
326
+
327
+ def count_flops_attn(model, _x, y):
328
+ """
329
+ A counter for the `thop` package to count the operations in an
330
+ attention operation.
331
+ Meant to be used like:
332
+ macs, params = thop.profile(
333
+ model,
334
+ inputs=(inputs, timestamps),
335
+ custom_ops={QKVAttention: QKVAttention.count_flops},
336
+ )
337
+ """
338
+ b, c, *spatial = y[0].shape
339
+ num_spatial = int(np.prod(spatial))
340
+ # We perform two matmuls with the same number of ops.
341
+ # The first computes the weight matrix, the second computes
342
+ # the combination of the value vectors.
343
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
344
+ model.total_ops += th.DoubleTensor([matmul_ops])
345
+
346
+
347
+ class QKVAttentionLegacy(nn.Module):
348
+ """
349
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
350
+ """
351
+
352
+ def __init__(self, n_heads):
353
+ super().__init__()
354
+ self.n_heads = n_heads
355
+
356
+ def forward(self, qkv):
357
+ """
358
+ Apply QKV attention.
359
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
+ :return: an [N x (H * C) x T] tensor after attention.
361
+ """
362
+ bs, width, length = qkv.shape
363
+ assert width % (3 * self.n_heads) == 0
364
+ ch = width // (3 * self.n_heads)
365
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
+ scale = 1 / math.sqrt(math.sqrt(ch))
367
+ weight = th.einsum(
368
+ "bct,bcs->bts", q * scale, k * scale
369
+ ) # More stable with f16 than dividing afterwards
370
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
+ a = th.einsum("bts,bcs->bct", weight, v)
372
+ return a.reshape(bs, -1, length)
373
+
374
+ @staticmethod
375
+ def count_flops(model, _x, y):
376
+ return count_flops_attn(model, _x, y)
377
+
378
+
379
+ class QKVAttention(nn.Module):
380
+ """
381
+ A module which performs QKV attention and splits in a different order.
382
+ """
383
+
384
+ def __init__(self, n_heads):
385
+ super().__init__()
386
+ self.n_heads = n_heads
387
+
388
+ def forward(self, qkv):
389
+ """
390
+ Apply QKV attention.
391
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
+ :return: an [N x (H * C) x T] tensor after attention.
393
+ """
394
+ bs, width, length = qkv.shape
395
+ assert width % (3 * self.n_heads) == 0
396
+ ch = width // (3 * self.n_heads)
397
+ q, k, v = qkv.chunk(3, dim=1)
398
+ scale = 1 / math.sqrt(math.sqrt(ch))
399
+ weight = th.einsum(
400
+ "bct,bcs->bts",
401
+ (q * scale).view(bs * self.n_heads, ch, length),
402
+ (k * scale).view(bs * self.n_heads, ch, length),
403
+ ) # More stable with f16 than dividing afterwards
404
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
+ return a.reshape(bs, -1, length)
407
+
408
+ @staticmethod
409
+ def count_flops(model, _x, y):
410
+ return count_flops_attn(model, _x, y)
411
+
412
+
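Both attention helpers above accept a fused `qkv` tensor of width `3 * heads * channels` and return `heads * channels`; they differ only in whether heads or q/k/v are split first. A small shape check with made-up sizes:

```python
import torch as th

heads, ch, tokens = 2, 8, 16
qkv = th.randn(1, 3 * heads * ch, tokens)     # [N, 3*H*C, T]
legacy = QKVAttentionLegacy(heads)(qkv)       # heads split before q, k, v
new = QKVAttention(heads)(qkv)                # q, k, v split before heads
assert legacy.shape == new.shape == (1, heads * ch, tokens)
```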
413
+ class Timestep(nn.Module):
414
+ def __init__(self, dim):
415
+ super().__init__()
416
+ self.dim = dim
417
+
418
+ def forward(self, t):
419
+ return timestep_embedding(t, self.dim)
420
+
421
+
422
+ class UNetModel(nn.Module):
423
+ """
424
+ The full UNet model with attention and timestep embedding.
425
+ :param in_channels: channels in the input Tensor.
426
+ :param model_channels: base channel count for the model.
427
+ :param out_channels: channels in the output Tensor.
428
+ :param num_res_blocks: number of residual blocks per downsample.
429
+ :param attention_resolutions: a collection of downsample rates at which
430
+ attention will take place. May be a set, list, or tuple.
431
+ For example, if this contains 4, then at 4x downsampling, attention
432
+ will be used.
433
+ :param dropout: the dropout probability.
434
+ :param channel_mult: channel multiplier for each level of the UNet.
435
+ :param conv_resample: if True, use learned convolutions for upsampling and
436
+ downsampling.
437
+ :param dims: determines if the signal is 1D, 2D, or 3D.
438
+ :param num_classes: if specified (as an int), then this model will be
439
+ class-conditional with `num_classes` classes.
440
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
441
+ :param num_heads: the number of attention heads in each attention layer.
442
+ :param num_head_channels: if specified, ignore num_heads and instead use
443
+ a fixed channel width per attention head.
444
+ :param num_heads_upsample: works with num_heads to set a different number
445
+ of heads for upsampling. Deprecated.
446
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
447
+ :param resblock_updown: use residual blocks for up/downsampling.
448
+ :param use_new_attention_order: use a different attention pattern for potentially
449
+ increased efficiency.
450
+ """
451
+
452
+ def __init__(
453
+ self,
454
+ image_size,
455
+ in_channels,
456
+ model_channels,
457
+ out_channels,
458
+ num_res_blocks,
459
+ attention_resolutions,
460
+ dropout=0,
461
+ channel_mult=(1, 2, 4, 8),
462
+ conv_resample=True,
463
+ dims=2,
464
+ num_classes=None,
465
+ use_checkpoint=False,
466
+ use_fp16=False,
467
+ use_bf16=False,
468
+ num_heads=-1,
469
+ num_head_channels=-1,
470
+ num_heads_upsample=-1,
471
+ use_scale_shift_norm=False,
472
+ resblock_updown=False,
473
+ use_new_attention_order=False,
474
+ use_spatial_transformer=False, # custom transformer support
475
+ transformer_depth=1, # custom transformer support
476
+ context_dim=None, # custom transformer support
477
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
478
+ legacy=True,
479
+ disable_self_attentions=None,
480
+ num_attention_blocks=None,
481
+ disable_middle_self_attn=False,
482
+ use_linear_in_transformer=False,
483
+ adm_in_channels=None,
484
+ ):
485
+ super().__init__()
486
+ if use_spatial_transformer:
487
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
488
+
489
+ if context_dim is not None:
490
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
491
+ from omegaconf.listconfig import ListConfig
492
+ if type(context_dim) == ListConfig:
493
+ context_dim = list(context_dim)
494
+
495
+ if num_heads_upsample == -1:
496
+ num_heads_upsample = num_heads
497
+
498
+ if num_heads == -1:
499
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
500
+
501
+ if num_head_channels == -1:
502
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
503
+
504
+ self.image_size = image_size
505
+ self.in_channels = in_channels
506
+ self.model_channels = model_channels
507
+ self.out_channels = out_channels
508
+ if isinstance(num_res_blocks, int):
509
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
510
+ else:
511
+ if len(num_res_blocks) != len(channel_mult):
512
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
513
+ "as a list/tuple (per-level) with the same length as channel_mult")
514
+ self.num_res_blocks = num_res_blocks
515
+ if disable_self_attentions is not None:
516
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
517
+ assert len(disable_self_attentions) == len(channel_mult)
518
+ if num_attention_blocks is not None:
519
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
520
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
521
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
522
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
523
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
524
+ f"attention will still not be set.")
525
+
526
+ self.attention_resolutions = attention_resolutions
527
+ self.dropout = dropout
528
+ self.channel_mult = channel_mult
529
+ self.conv_resample = conv_resample
530
+ self.num_classes = num_classes
531
+ self.use_checkpoint = use_checkpoint
532
+ self.dtype = th.float16 if use_fp16 else th.float32
533
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
534
+ self.num_heads = num_heads
535
+ self.num_head_channels = num_head_channels
536
+ self.num_heads_upsample = num_heads_upsample
537
+ self.predict_codebook_ids = n_embed is not None
538
+
539
+ time_embed_dim = model_channels * 4
540
+ self.time_embed = nn.Sequential(
541
+ linear(model_channels, time_embed_dim),
542
+ nn.SiLU(),
543
+ linear(time_embed_dim, time_embed_dim),
544
+ )
545
+
546
+ if self.num_classes is not None:
547
+ if isinstance(self.num_classes, int):
548
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
549
+ elif self.num_classes == "continuous":
550
+ print("setting up linear c_adm embedding layer")
551
+ self.label_emb = nn.Linear(1, time_embed_dim)
552
+ elif self.num_classes == "sequential":
553
+ assert adm_in_channels is not None
554
+ self.label_emb = nn.Sequential(
555
+ nn.Sequential(
556
+ linear(adm_in_channels, time_embed_dim),
557
+ nn.SiLU(),
558
+ linear(time_embed_dim, time_embed_dim),
559
+ )
560
+ )
561
+ else:
562
+ raise ValueError(f"unsupported num_classes: {self.num_classes}")
563
+
564
+ self.input_blocks = nn.ModuleList(
565
+ [
566
+ TimestepEmbedSequential(
567
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
568
+ )
569
+ ]
570
+ )
571
+ self._feature_size = model_channels
572
+ input_block_chans = [model_channels]
573
+ ch = model_channels
574
+ ds = 1
575
+ for level, mult in enumerate(channel_mult):
576
+ for nr in range(self.num_res_blocks[level]):
577
+ layers = [
578
+ ResBlock(
579
+ ch,
580
+ time_embed_dim,
581
+ dropout,
582
+ out_channels=mult * model_channels,
583
+ dims=dims,
584
+ use_checkpoint=use_checkpoint,
585
+ use_scale_shift_norm=use_scale_shift_norm,
586
+ )
587
+ ]
588
+ ch = mult * model_channels
589
+ if ds in attention_resolutions:
590
+ if num_head_channels == -1:
591
+ dim_head = ch // num_heads
592
+ else:
593
+ num_heads = ch // num_head_channels
594
+ dim_head = num_head_channels
595
+ if legacy:
596
+ #num_heads = 1
597
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
598
+ if exists(disable_self_attentions):
599
+ disabled_sa = disable_self_attentions[level]
600
+ else:
601
+ disabled_sa = False
602
+
603
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
604
+ layers.append(
605
+ AttentionBlock(
606
+ ch,
607
+ use_checkpoint=use_checkpoint,
608
+ num_heads=num_heads,
609
+ num_head_channels=dim_head,
610
+ use_new_attention_order=use_new_attention_order,
611
+ ) if not use_spatial_transformer else SpatialTransformer(
612
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
613
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
614
+ use_checkpoint=use_checkpoint
615
+ )
616
+ )
617
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
618
+ self._feature_size += ch
619
+ input_block_chans.append(ch)
620
+ if level != len(channel_mult) - 1:
621
+ out_ch = ch
622
+ self.input_blocks.append(
623
+ TimestepEmbedSequential(
624
+ ResBlock(
625
+ ch,
626
+ time_embed_dim,
627
+ dropout,
628
+ out_channels=out_ch,
629
+ dims=dims,
630
+ use_checkpoint=use_checkpoint,
631
+ use_scale_shift_norm=use_scale_shift_norm,
632
+ down=True,
633
+ )
634
+ if resblock_updown
635
+ else Downsample(
636
+ ch, conv_resample, dims=dims, out_channels=out_ch
637
+ )
638
+ )
639
+ )
640
+ ch = out_ch
641
+ input_block_chans.append(ch)
642
+ ds *= 2
643
+ self._feature_size += ch
644
+
645
+ if num_head_channels == -1:
646
+ dim_head = ch // num_heads
647
+ else:
648
+ num_heads = ch // num_head_channels
649
+ dim_head = num_head_channels
650
+ if legacy:
651
+ #num_heads = 1
652
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
653
+ self.middle_block = TimestepEmbedSequential(
654
+ ResBlock(
655
+ ch,
656
+ time_embed_dim,
657
+ dropout,
658
+ dims=dims,
659
+ use_checkpoint=use_checkpoint,
660
+ use_scale_shift_norm=use_scale_shift_norm,
661
+ ),
662
+ AttentionBlock(
663
+ ch,
664
+ use_checkpoint=use_checkpoint,
665
+ num_heads=num_heads,
666
+ num_head_channels=dim_head,
667
+ use_new_attention_order=use_new_attention_order,
668
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
669
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
670
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
671
+ use_checkpoint=use_checkpoint
672
+ ),
673
+ ResBlock(
674
+ ch,
675
+ time_embed_dim,
676
+ dropout,
677
+ dims=dims,
678
+ use_checkpoint=use_checkpoint,
679
+ use_scale_shift_norm=use_scale_shift_norm,
680
+ ),
681
+ )
682
+ self._feature_size += ch
683
+
684
+ self.output_blocks = nn.ModuleList([])
685
+ for level, mult in list(enumerate(channel_mult))[::-1]:
686
+ for i in range(self.num_res_blocks[level] + 1):
687
+ ich = input_block_chans.pop()
688
+ layers = [
689
+ ResBlock(
690
+ ch + ich,
691
+ time_embed_dim,
692
+ dropout,
693
+ out_channels=model_channels * mult,
694
+ dims=dims,
695
+ use_checkpoint=use_checkpoint,
696
+ use_scale_shift_norm=use_scale_shift_norm,
697
+ )
698
+ ]
699
+ ch = model_channels * mult
700
+ if ds in attention_resolutions:
701
+ if num_head_channels == -1:
702
+ dim_head = ch // num_heads
703
+ else:
704
+ num_heads = ch // num_head_channels
705
+ dim_head = num_head_channels
706
+ if legacy:
707
+ #num_heads = 1
708
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
709
+ if exists(disable_self_attentions):
710
+ disabled_sa = disable_self_attentions[level]
711
+ else:
712
+ disabled_sa = False
713
+
714
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
715
+ layers.append(
716
+ AttentionBlock(
717
+ ch,
718
+ use_checkpoint=use_checkpoint,
719
+ num_heads=num_heads_upsample,
720
+ num_head_channels=dim_head,
721
+ use_new_attention_order=use_new_attention_order,
722
+ ) if not use_spatial_transformer else SpatialTransformer(
723
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
724
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
725
+ use_checkpoint=use_checkpoint
726
+ )
727
+ )
728
+ if level and i == self.num_res_blocks[level]:
729
+ out_ch = ch
730
+ layers.append(
731
+ ResBlock(
732
+ ch,
733
+ time_embed_dim,
734
+ dropout,
735
+ out_channels=out_ch,
736
+ dims=dims,
737
+ use_checkpoint=use_checkpoint,
738
+ use_scale_shift_norm=use_scale_shift_norm,
739
+ up=True,
740
+ )
741
+ if resblock_updown
742
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
743
+ )
744
+ ds //= 2
745
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
746
+ self._feature_size += ch
747
+
748
+ self.out = nn.Sequential(
749
+ normalization(ch),
750
+ nn.SiLU(),
751
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
752
+ )
753
+ if self.predict_codebook_ids:
754
+ self.id_predictor = nn.Sequential(
755
+ normalization(ch),
756
+ conv_nd(dims, model_channels, n_embed, 1),
757
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
758
+ )
759
+
760
+ def convert_to_fp16(self):
761
+ """
762
+ Convert the torso of the model to float16.
763
+ """
764
+ self.input_blocks.apply(convert_module_to_f16)
765
+ self.middle_block.apply(convert_module_to_f16)
766
+ self.output_blocks.apply(convert_module_to_f16)
767
+
768
+ def convert_to_fp32(self):
769
+ """
770
+ Convert the torso of the model to float32.
771
+ """
772
+ self.input_blocks.apply(convert_module_to_f32)
773
+ self.middle_block.apply(convert_module_to_f32)
774
+ self.output_blocks.apply(convert_module_to_f32)
775
+
776
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
777
+ """
778
+ Apply the model to an input batch.
779
+ :param x: an [N x C x ...] Tensor of inputs.
780
+ :param timesteps: a 1-D batch of timesteps.
781
+ :param context: conditioning plugged in via crossattn
782
+ :param y: an [N] Tensor of labels, if class-conditional.
783
+ :return: an [N x C x ...] Tensor of outputs.
784
+ """
785
+ assert (y is not None) == (
786
+ self.num_classes is not None
787
+ ), "must specify y if and only if the model is class-conditional"
788
+ hs = []
789
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
790
+ emb = self.time_embed(t_emb)
791
+
792
+ if self.num_classes is not None:
793
+ assert y.shape[0] == x.shape[0]
794
+ emb = emb + self.label_emb(y)
795
+
796
+ h = x.type(self.dtype)
797
+ for module in self.input_blocks:
798
+ h = module(h, emb, context)
799
+ hs.append(h)
800
+ h = self.middle_block(h, emb, context)
801
+ for module in self.output_blocks:
802
+ h = th.cat([h, hs.pop()], dim=1)
803
+ h = module(h, emb, context)
804
+ h = h.type(x.dtype)
805
+ if self.predict_codebook_ids:
806
+ return self.id_predictor(h)
807
+ else:
808
+ return self.out(h)
809
+
810
+
811
+ class MultiViewUNetModel(nn.Module):
812
+ """
813
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
814
+ :param in_channels: channels in the input Tensor.
815
+ :param model_channels: base channel count for the model.
816
+ :param out_channels: channels in the output Tensor.
817
+ :param num_res_blocks: number of residual blocks per downsample.
818
+ :param attention_resolutions: a collection of downsample rates at which
819
+ attention will take place. May be a set, list, or tuple.
820
+ For example, if this contains 4, then at 4x downsampling, attention
821
+ will be used.
822
+ :param dropout: the dropout probability.
823
+ :param channel_mult: channel multiplier for each level of the UNet.
824
+ :param conv_resample: if True, use learned convolutions for upsampling and
825
+ downsampling.
826
+ :param dims: determines if the signal is 1D, 2D, or 3D.
827
+ :param num_classes: if specified (as an int), then this model will be
828
+ class-conditional with `num_classes` classes.
829
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
830
+ :param num_heads: the number of attention heads in each attention layer.
831
+ :param num_head_channels: if specified, ignore num_heads and instead use
832
+ a fixed channel width per attention head.
833
+ :param num_heads_upsample: works with num_heads to set a different number
834
+ of heads for upsampling. Deprecated.
835
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
836
+ :param resblock_updown: use residual blocks for up/downsampling.
837
+ :param use_new_attention_order: use a different attention pattern for potentially
838
+ increased efficiency.
839
+ :param camera_dim: dimensionality of camera input.
840
+ """
841
+
842
+ def __init__(
843
+ self,
844
+ image_size,
845
+ in_channels,
846
+ model_channels,
847
+ out_channels,
848
+ num_res_blocks,
849
+ attention_resolutions,
850
+ dropout=0,
851
+ channel_mult=(1, 2, 4, 8),
852
+ conv_resample=True,
853
+ dims=2,
854
+ num_classes=None,
855
+ use_checkpoint=False,
856
+ use_fp16=False,
857
+ use_bf16=False,
858
+ num_heads=-1,
859
+ num_head_channels=-1,
860
+ num_heads_upsample=-1,
861
+ use_scale_shift_norm=False,
862
+ resblock_updown=False,
863
+ use_new_attention_order=False,
864
+ use_spatial_transformer=False, # custom transformer support
865
+ transformer_depth=1, # custom transformer support
866
+ context_dim=None, # custom transformer support
867
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
868
+ legacy=True,
869
+ disable_self_attentions=None,
870
+ num_attention_blocks=None,
871
+ disable_middle_self_attn=False,
872
+ use_linear_in_transformer=False,
873
+ adm_in_channels=None,
874
+ camera_dim=None,
875
+ ):
876
+ super().__init__()
877
+ if use_spatial_transformer:
878
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
879
+
880
+ if context_dim is not None:
881
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
882
+ from omegaconf.listconfig import ListConfig
883
+ if type(context_dim) == ListConfig:
884
+ context_dim = list(context_dim)
885
+
886
+ if num_heads_upsample == -1:
887
+ num_heads_upsample = num_heads
888
+
889
+ if num_heads == -1:
890
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
891
+
892
+ if num_head_channels == -1:
893
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
894
+
895
+ self.image_size = image_size
896
+ self.in_channels = in_channels
897
+ self.model_channels = model_channels
898
+ self.out_channels = out_channels
899
+ if isinstance(num_res_blocks, int):
900
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
901
+ else:
902
+ if len(num_res_blocks) != len(channel_mult):
903
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
904
+ "as a list/tuple (per-level) with the same length as channel_mult")
905
+ self.num_res_blocks = num_res_blocks
906
+ if disable_self_attentions is not None:
907
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
908
+ assert len(disable_self_attentions) == len(channel_mult)
909
+ if num_attention_blocks is not None:
910
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
911
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
912
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
913
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
914
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
915
+ f"attention will still not be set.")
916
+
917
+ self.attention_resolutions = attention_resolutions
918
+ self.dropout = dropout
919
+ self.channel_mult = channel_mult
920
+ self.conv_resample = conv_resample
921
+ self.num_classes = num_classes
922
+ self.use_checkpoint = use_checkpoint
923
+ self.dtype = th.float16 if use_fp16 else th.float32
924
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
925
+ self.num_heads = num_heads
926
+ self.num_head_channels = num_head_channels
927
+ self.num_heads_upsample = num_heads_upsample
928
+ self.predict_codebook_ids = n_embed is not None
929
+
930
+ time_embed_dim = model_channels * 4
931
+ self.time_embed = nn.Sequential(
932
+ linear(model_channels, time_embed_dim),
933
+ nn.SiLU(),
934
+ linear(time_embed_dim, time_embed_dim),
935
+ )
936
+
937
+ if camera_dim is not None:
938
+ time_embed_dim = model_channels * 4
939
+ self.camera_embed = nn.Sequential(
940
+ linear(camera_dim, time_embed_dim),
941
+ nn.SiLU(),
942
+ linear(time_embed_dim, time_embed_dim),
943
+ )
944
+
945
+ if self.num_classes is not None:
946
+ if isinstance(self.num_classes, int):
947
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
948
+ elif self.num_classes == "continuous":
949
+ print("setting up linear c_adm embedding layer")
950
+ self.label_emb = nn.Linear(1, time_embed_dim)
951
+ elif self.num_classes == "sequential":
952
+ assert adm_in_channels is not None
953
+ self.label_emb = nn.Sequential(
954
+ nn.Sequential(
955
+ linear(adm_in_channels, time_embed_dim),
956
+ nn.SiLU(),
957
+ linear(time_embed_dim, time_embed_dim),
958
+ )
959
+ )
960
+ else:
961
+ raise ValueError(f"unsupported num_classes: {self.num_classes}")
962
+
963
+ self.input_blocks = nn.ModuleList(
964
+ [
965
+ TimestepEmbedSequential(
966
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
967
+ )
968
+ ]
969
+ )
970
+ self._feature_size = model_channels
971
+ input_block_chans = [model_channels]
972
+ ch = model_channels
973
+ ds = 1
974
+ for level, mult in enumerate(channel_mult):
975
+ for nr in range(self.num_res_blocks[level]):
976
+ layers = [
977
+ ResBlock(
978
+ ch,
979
+ time_embed_dim,
980
+ dropout,
981
+ out_channels=mult * model_channels,
982
+ dims=dims,
983
+ use_checkpoint=use_checkpoint,
984
+ use_scale_shift_norm=use_scale_shift_norm,
985
+ )
986
+ ]
987
+ ch = mult * model_channels
988
+ if ds in attention_resolutions:
989
+ if num_head_channels == -1:
990
+ dim_head = ch // num_heads
991
+ else:
992
+ num_heads = ch // num_head_channels
993
+ dim_head = num_head_channels
994
+ if legacy:
995
+ #num_heads = 1
996
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
997
+ if exists(disable_self_attentions):
998
+ disabled_sa = disable_self_attentions[level]
999
+ else:
1000
+ disabled_sa = False
1001
+
1002
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
1003
+ layers.append(
1004
+ AttentionBlock(
1005
+ ch,
1006
+ use_checkpoint=use_checkpoint,
1007
+ num_heads=num_heads,
1008
+ num_head_channels=dim_head,
1009
+ use_new_attention_order=use_new_attention_order,
1010
+ ) if not use_spatial_transformer else SpatialTransformer3D(
1011
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1012
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
1013
+ use_checkpoint=use_checkpoint
1014
+ )
1015
+ )
1016
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1017
+ self._feature_size += ch
1018
+ input_block_chans.append(ch)
1019
+ if level != len(channel_mult) - 1:
1020
+ out_ch = ch
1021
+ self.input_blocks.append(
1022
+ TimestepEmbedSequential(
1023
+ ResBlock(
1024
+ ch,
1025
+ time_embed_dim,
1026
+ dropout,
1027
+ out_channels=out_ch,
1028
+ dims=dims,
1029
+ use_checkpoint=use_checkpoint,
1030
+ use_scale_shift_norm=use_scale_shift_norm,
1031
+ down=True,
1032
+ )
1033
+ if resblock_updown
1034
+ else Downsample(
1035
+ ch, conv_resample, dims=dims, out_channels=out_ch
1036
+ )
1037
+ )
1038
+ )
1039
+ ch = out_ch
1040
+ input_block_chans.append(ch)
1041
+ ds *= 2
1042
+ self._feature_size += ch
1043
+
1044
+ if num_head_channels == -1:
1045
+ dim_head = ch // num_heads
1046
+ else:
1047
+ num_heads = ch // num_head_channels
1048
+ dim_head = num_head_channels
1049
+ if legacy:
1050
+ #num_heads = 1
1051
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1052
+ self.middle_block = TimestepEmbedSequential(
1053
+ ResBlock(
1054
+ ch,
1055
+ time_embed_dim,
1056
+ dropout,
1057
+ dims=dims,
1058
+ use_checkpoint=use_checkpoint,
1059
+ use_scale_shift_norm=use_scale_shift_norm,
1060
+ ),
1061
+ AttentionBlock(
1062
+ ch,
1063
+ use_checkpoint=use_checkpoint,
1064
+ num_heads=num_heads,
1065
+ num_head_channels=dim_head,
1066
+ use_new_attention_order=use_new_attention_order,
1067
+ ) if not use_spatial_transformer else SpatialTransformer3D( # always uses a self-attn
1068
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1069
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
1070
+ use_checkpoint=use_checkpoint
1071
+ ),
1072
+ ResBlock(
1073
+ ch,
1074
+ time_embed_dim,
1075
+ dropout,
1076
+ dims=dims,
1077
+ use_checkpoint=use_checkpoint,
1078
+ use_scale_shift_norm=use_scale_shift_norm,
1079
+ ),
1080
+ )
1081
+ self._feature_size += ch
1082
+
1083
+ self.output_blocks = nn.ModuleList([])
1084
+ for level, mult in list(enumerate(channel_mult))[::-1]:
1085
+ for i in range(self.num_res_blocks[level] + 1):
1086
+ ich = input_block_chans.pop()
1087
+ layers = [
1088
+ ResBlock(
1089
+ ch + ich,
1090
+ time_embed_dim,
1091
+ dropout,
1092
+ out_channels=model_channels * mult,
1093
+ dims=dims,
1094
+ use_checkpoint=use_checkpoint,
1095
+ use_scale_shift_norm=use_scale_shift_norm,
1096
+ )
1097
+ ]
1098
+ ch = model_channels * mult
1099
+ if ds in attention_resolutions:
1100
+ if num_head_channels == -1:
1101
+ dim_head = ch // num_heads
1102
+ else:
1103
+ num_heads = ch // num_head_channels
1104
+ dim_head = num_head_channels
1105
+ if legacy:
1106
+ #num_heads = 1
1107
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1108
+ if exists(disable_self_attentions):
1109
+ disabled_sa = disable_self_attentions[level]
1110
+ else:
1111
+ disabled_sa = False
1112
+
1113
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
1114
+ layers.append(
1115
+ AttentionBlock(
1116
+ ch,
1117
+ use_checkpoint=use_checkpoint,
1118
+ num_heads=num_heads_upsample,
1119
+ num_head_channels=dim_head,
1120
+ use_new_attention_order=use_new_attention_order,
1121
+ ) if not use_spatial_transformer else SpatialTransformer3D(
1122
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1123
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
1124
+ use_checkpoint=use_checkpoint
1125
+ )
1126
+ )
1127
+ if level and i == self.num_res_blocks[level]:
1128
+ out_ch = ch
1129
+ layers.append(
1130
+ ResBlock(
1131
+ ch,
1132
+ time_embed_dim,
1133
+ dropout,
1134
+ out_channels=out_ch,
1135
+ dims=dims,
1136
+ use_checkpoint=use_checkpoint,
1137
+ use_scale_shift_norm=use_scale_shift_norm,
1138
+ up=True,
1139
+ )
1140
+ if resblock_updown
1141
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1142
+ )
1143
+ ds //= 2
1144
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
1145
+ self._feature_size += ch
1146
+
1147
+ self.out = nn.Sequential(
1148
+ normalization(ch),
1149
+ nn.SiLU(),
1150
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1151
+ )
1152
+ if self.predict_codebook_ids:
1153
+ self.id_predictor = nn.Sequential(
1154
+ normalization(ch),
1155
+ conv_nd(dims, model_channels, n_embed, 1),
1156
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1157
+ )
1158
+
1159
+ def convert_to_fp16(self):
1160
+ """
1161
+ Convert the torso of the model to float16.
1162
+ """
1163
+ self.input_blocks.apply(convert_module_to_f16)
1164
+ self.middle_block.apply(convert_module_to_f16)
1165
+ self.output_blocks.apply(convert_module_to_f16)
1166
+
1167
+ def convert_to_fp32(self):
1168
+ """
1169
+ Convert the torso of the model to float32.
1170
+ """
1171
+ self.input_blocks.apply(convert_module_to_f32)
1172
+ self.middle_block.apply(convert_module_to_f32)
1173
+ self.output_blocks.apply(convert_module_to_f32)
1174
+
1175
+ def forward(self, x, timesteps=None, context=None, y=None, camera=None, num_frames=1, **kwargs):
1176
+ """
1177
+ Apply the model to an input batch.
1178
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
1179
+ :param timesteps: a 1-D batch of timesteps.
1180
+ :param context: conditioning plugged in via crossattn
1181
+ :param y: an [N] Tensor of labels, if class-conditional.
1182
+ :param num_frames: an integer indicating the number of frames (views) used for tensor reshaping.
1183
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
1184
+ """
1185
+ assert x.shape[0] % num_frames == 0, "[UNet] input batch size must be divisible by num_frames!"
1186
+ assert (y is not None) == (
1187
+ self.num_classes is not None
1188
+ ), "must specify y if and only if the model is class-conditional"
1189
+ hs = []
1190
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
1191
+ emb = self.time_embed(t_emb)
1192
+
1193
+ if self.num_classes is not None:
1194
+ assert y.shape[0] == x.shape[0]
1195
+ emb = emb + self.label_emb(y)
1196
+
1197
+ # Add camera embeddings
1198
+ if camera is not None:
1199
+ assert camera.shape[0] == emb.shape[0]
1200
+ emb = emb + self.camera_embed(camera)
1201
+
1202
+ h = x.type(self.dtype)
1203
+ for module in self.input_blocks:
1204
+ h = module(h, emb, context, num_frames=num_frames)
1205
+ hs.append(h)
1206
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
1207
+ for module in self.output_blocks:
1208
+ h = th.cat([h, hs.pop()], dim=1)
1209
+ h = module(h, emb, context, num_frames=num_frames)
1210
+ h = h.type(x.dtype)
1211
+ if self.predict_codebook_ids:
1212
+ return self.id_predictor(h)
1213
+ else:
1214
+ return self.out(h)
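To sanity-check the multi-view wiring end to end, the sketch below instantiates a deliberately tiny `MultiViewUNetModel` and runs four views of one object through it. Every hyper-parameter here is made up for the smoke test; the released sd-v1.5 / sd-v2.1 checkpoints use much larger values (see the YAML configs referenced in the README).

```python
import torch as th

model = MultiViewUNetModel(
    image_size=8, in_channels=4, model_channels=32, out_channels=4,
    num_res_blocks=1, attention_resolutions=[1], channel_mult=[1, 2],
    num_heads=4, use_spatial_transformer=True, transformer_depth=1,
    context_dim=16, camera_dim=16,
)
num_frames = 4                            # four views of a single object
x = th.randn(num_frames, 4, 8, 8)         # latents, [(N*F) x C x H x W]
t = th.randint(0, 1000, (num_frames,))
context = th.randn(num_frames, 77, 16)    # per-view text conditioning
camera = th.randn(num_frames, 16)         # flattened 4x4 camera-to-world matrices
out = model(x, timesteps=t, context=context, camera=camera, num_frames=num_frames)
assert out.shape == x.shape
```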
scripts/pipeline_mvdream.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import torch
5
+ from packaging import version
6
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
7
+
8
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
10
+ from diffusers.utils import (
11
+ deprecate,
12
+ is_accelerate_available,
13
+ is_accelerate_version,
14
+ logging,
15
+ replace_example_docstring,
16
+ )
17
+
18
+ try:
19
+ from diffusers import randn_tensor # old import
20
+ except ImportError:
21
+ from diffusers.utils.torch_utils import randn_tensor # new import
22
+
23
+ from diffusers.configuration_utils import FrozenDict
24
+ import PIL
25
+ import numpy as np
26
+ import kornia
27
+ from diffusers.configuration_utils import ConfigMixin
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+
30
+ from models import MultiViewUNetModel
31
+ from diffusers.schedulers import DDIMScheduler
32
+
33
+ EXAMPLE_DOC_STRING = ""
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+
39
+ def create_camera_to_world_matrix(elevation, azimuth):
40
+ elevation = np.radians(elevation)
41
+ azimuth = np.radians(azimuth)
42
+ # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
43
+ x = np.cos(elevation) * np.sin(azimuth)
44
+ y = np.sin(elevation)
45
+ z = np.cos(elevation) * np.cos(azimuth)
46
+
47
+ # Calculate camera position, target, and up vectors
48
+ camera_pos = np.array([x, y, z])
49
+ target = np.array([0, 0, 0])
50
+ up = np.array([0, 1, 0])
51
+
52
+ # Construct view matrix
53
+ forward = target - camera_pos
54
+ forward /= np.linalg.norm(forward)
55
+ right = np.cross(forward, up)
56
+ right /= np.linalg.norm(right)
57
+ new_up = np.cross(right, forward)
58
+ new_up /= np.linalg.norm(new_up)
59
+ cam2world = np.eye(4)
60
+ cam2world[:3, :3] = np.array([right, new_up, -forward]).T
61
+ cam2world[:3, 3] = camera_pos
62
+ return cam2world
63
+
64
+ def convert_opengl_to_blender(camera_matrix):
65
+ if isinstance(camera_matrix, np.ndarray):
66
+ # Construct transformation matrix to convert from OpenGL space to Blender space
67
+ flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
68
+ camera_matrix_blender = np.dot(flip_yz, camera_matrix)
69
+ else:
70
+ # Construct transformation matrix to convert from OpenGL space to Blender space
71
+ flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
72
+ if camera_matrix.ndim == 3:
73
+ flip_yz = flip_yz.unsqueeze(0)
74
+ camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
75
+ return camera_matrix_blender
76
+
77
+ def get_camera(num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True):
78
+ angle_gap = azimuth_span / num_frames
79
+ cameras = []
80
+ for azimuth in np.arange(azimuth_start, azimuth_span+azimuth_start, angle_gap):
81
+ camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
82
+ if blender_coord:
83
+ camera_matrix = convert_opengl_to_blender(camera_matrix)
84
+ cameras.append(camera_matrix.flatten())
85
+ return torch.tensor(np.stack(cameras, 0)).float()
86
+
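These three helpers produce the flattened camera-to-world matrices that `MultiViewUNetModel` consumes through its `camera` argument. A small usage sketch (the doubling for classifier-free guidance is an assumption about how the pipeline batches the unconditional and conditional branches):

```python
import torch

cameras = get_camera(4, elevation=15, azimuth_start=0, azimuth_span=360)
print(cameras.shape)    # torch.Size([4, 16]) -- one flattened 4x4 c2w matrix per view
cameras_cfg = torch.cat([cameras, cameras], dim=0)    # repeat for the unconditional branch
```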
87
+ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
88
+ def __init__(
89
+ self,
90
+ vae: AutoencoderKL,
91
+ unet: MultiViewUNetModel,
92
+ tokenizer: CLIPTokenizer,
93
+ text_encoder: CLIPTextModel,
94
+ scheduler: DDIMScheduler,
95
+ safety_checker: Optional[StableDiffusionSafetyChecker],
96
+ feature_extractor: Optional[CLIPFeatureExtractor],
97
+ requires_safety_checker: bool = True,
98
+ ):
99
+ super().__init__()
100
+
101
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
102
+ deprecation_message = (f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
103
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
104
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
105
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
106
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
107
+ " file")
108
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
109
+ new_config = dict(scheduler.config)
110
+ new_config["steps_offset"] = 1
111
+ scheduler._internal_dict = FrozenDict(new_config)
112
+
113
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
114
+ deprecation_message = (f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
115
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
116
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
117
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
118
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file")
119
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
120
+ new_config = dict(scheduler.config)
121
+ new_config["clip_sample"] = False
122
+ scheduler._internal_dict = FrozenDict(new_config)
123
+
124
+ if safety_checker is None and requires_safety_checker:
125
+ logger.warning(f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
126
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
127
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
128
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
129
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
130
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")
131
+
132
+ if safety_checker is not None and feature_extractor is None:
133
+ raise ValueError("Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
134
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.")
135
+
136
+ self.register_modules(
137
+ vae=vae,
138
+ unet=unet,
139
+ scheduler=scheduler,
140
+ tokenizer=tokenizer,
141
+ text_encoder=text_encoder,
142
+ safety_checker=safety_checker,
143
+ feature_extractor=feature_extractor,
144
+ )
145
+ self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1)
146
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
147
+ # self.model_mode = None
148
+
149
+ def enable_vae_slicing(self):
150
+ r"""
151
+ Enable sliced VAE decoding.
152
+
153
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
154
+ steps. This is useful to save some memory and allow larger batch sizes.
155
+ """
156
+ self.vae.enable_slicing()
157
+
158
+ def disable_vae_slicing(self):
159
+ r"""
160
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
161
+ computing decoding in one step.
162
+ """
163
+ self.vae.disable_slicing()
164
+
165
+ def enable_vae_tiling(self):
166
+ r"""
167
+ Enable tiled VAE decoding.
168
+
169
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
170
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
171
+ """
172
+ self.vae.enable_tiling()
173
+
174
+ def disable_vae_tiling(self):
175
+ r"""
176
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
177
+ computing decoding in one step.
178
+ """
179
+ self.vae.disable_tiling()
180
+
181
+ def enable_sequential_cpu_offload(self, gpu_id=0):
182
+ r"""
183
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
184
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
185
+ `torch.device('meta')` and loaded to the GPU only when their specific submodule has its `forward` method called.
186
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
187
+ `enable_model_cpu_offload`, but performance is lower.
188
+ """
189
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
190
+ from accelerate import cpu_offload
191
+ else:
192
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
193
+
194
+ device = torch.device(f"cuda:{gpu_id}")
195
+
196
+ if self.device.type != "cpu":
197
+ self.to("cpu", silence_dtype_warnings=True)
198
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
199
+
200
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
201
+ cpu_offload(cpu_offloaded_model, device)
202
+
203
+ if self.safety_checker is not None:
204
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
205
+
206
+ def enable_model_cpu_offload(self, gpu_id=0):
207
+ r"""
208
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
209
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
210
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
211
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
212
+ """
213
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
214
+ from accelerate import cpu_offload_with_hook
215
+ else:
216
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
217
+
218
+ device = torch.device(f"cuda:{gpu_id}")
219
+
220
+ if self.device.type != "cpu":
221
+ self.to("cpu", silence_dtype_warnings=True)
222
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
223
+
224
+ hook = None
225
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
226
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
227
+
228
+ if self.safety_checker is not None:
229
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
230
+
231
+ # We'll offload the last model manually.
232
+ self.final_offload_hook = hook
233
+
234
+ @property
235
+ def _execution_device(self):
236
+ r"""
237
+ Returns the device on which the pipeline's models will be executed. After calling
238
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
239
+ hooks.
240
+ """
241
+ if not hasattr(self.unet, "_hf_hook"):
242
+ return self.device
243
+ for module in self.unet.modules():
244
+ if (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None):
245
+ return torch.device(module._hf_hook.execution_device)
246
+ return self.device
247
+
248
+ def _encode_prompt(
249
+ self,
250
+ prompt,
251
+ device,
252
+ num_images_per_prompt,
253
+ do_classifier_free_guidance: bool,
254
+ negative_prompt=None,
255
+ prompt_embeds: Optional[torch.FloatTensor] = None,
256
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
257
+ ):
258
+ r"""
259
+ Encodes the prompt into text encoder hidden states.
260
+
261
+ Args:
262
+ prompt (`str` or `List[str]`, *optional*):
263
+ prompt to be encoded
264
+ device: (`torch.device`):
265
+ torch device
266
+ num_images_per_prompt (`int`):
267
+ number of images that should be generated per prompt
268
+ do_classifier_free_guidance (`bool`):
269
+ whether to use classifier free guidance or not
270
+ negative_prompt (`str` or `List[str]`, *optional*):
271
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
272
+ `negative_prompt_embeds` instead.
273
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
274
+ prompt_embeds (`torch.FloatTensor`, *optional*):
275
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
276
+ provided, text embeddings will be generated from `prompt` input argument.
277
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
278
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
279
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
280
+ argument.
281
+ """
282
+ if prompt is not None and isinstance(prompt, str):
283
+ batch_size = 1
284
+ elif prompt is not None and isinstance(prompt, list):
285
+ batch_size = len(prompt)
286
+ else:
287
+ batch_size = prompt_embeds.shape[0]
288
+
289
+ if prompt_embeds is None:
290
+ text_inputs = self.tokenizer(
291
+ prompt,
292
+ padding="max_length",
293
+ max_length=self.tokenizer.model_max_length,
294
+ truncation=True,
295
+ return_tensors="pt",
296
+ )
297
+ text_input_ids = text_inputs.input_ids
298
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
299
+
300
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
301
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
302
+ logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"
303
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}")
304
+
305
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
306
+ attention_mask = text_inputs.attention_mask.to(device)
307
+ else:
308
+ attention_mask = None
309
+
310
+ prompt_embeds = self.text_encoder(
311
+ text_input_ids.to(device),
312
+ attention_mask=attention_mask,
313
+ )
314
+ prompt_embeds = prompt_embeds[0]
315
+
316
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
317
+
318
+ bs_embed, seq_len, _ = prompt_embeds.shape
319
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
320
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
321
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
322
+
323
+ # get unconditional embeddings for classifier free guidance
324
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
325
+ uncond_tokens: List[str]
326
+ if negative_prompt is None:
327
+ uncond_tokens = [""] * batch_size
328
+ elif type(prompt) is not type(negative_prompt):
329
+ raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
330
+ f" {type(prompt)}.")
331
+ elif isinstance(negative_prompt, str):
332
+ uncond_tokens = [negative_prompt]
333
+ elif batch_size != len(negative_prompt):
334
+ raise ValueError(f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
335
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
336
+ " the batch size of `prompt`.")
337
+ else:
338
+ uncond_tokens = negative_prompt
339
+
340
+ max_length = prompt_embeds.shape[1]
341
+ uncond_input = self.tokenizer(
342
+ uncond_tokens,
343
+ padding="max_length",
344
+ max_length=max_length,
345
+ truncation=True,
346
+ return_tensors="pt",
347
+ )
348
+
349
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
350
+ attention_mask = uncond_input.attention_mask.to(device)
351
+ else:
352
+ attention_mask = None
353
+
354
+ negative_prompt_embeds = self.text_encoder(
355
+ uncond_input.input_ids.to(device),
356
+ attention_mask=attention_mask,
357
+ )
358
+ negative_prompt_embeds = negative_prompt_embeds[0]
359
+
360
+ if do_classifier_free_guidance:
361
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
362
+ seq_len = negative_prompt_embeds.shape[1]
363
+
364
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
365
+
366
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
367
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
368
+
369
+ # For classifier free guidance, we need to do two forward passes.
370
+ # Here we concatenate the unconditional and text embeddings into a single batch
371
+ # to avoid doing two forward passes
372
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
373
+
374
+ return prompt_embeds
375
+
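+ # A minimal shape sketch for the encoder above (the hidden size depends on which converted
+ # checkpoint is loaded; 768 is what the sd-v1.5 text encoder would give):
+ #
+ #   embeds = self._encode_prompt("a car", device, 1, True, "bad quality")
+ #   embeds.shape  # torch.Size([2, 77, 768]) -- row 0 is the negative, row 1 the positive prompt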
376
+ def run_safety_checker(self, image, device, dtype):
377
+ if self.safety_checker is not None:
378
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
379
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
380
+ else:
381
+ has_nsfw_concept = None
382
+ return image, has_nsfw_concept
383
+
384
+ def decode_latents(self, latents):
385
+ latents = 1 / self.vae.config.scaling_factor * latents
386
+ image = self.vae.decode(latents).sample
387
+ image = (image / 2 + 0.5).clamp(0, 1)
388
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
389
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
390
+ return image
391
+
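+ # A minimal sketch of decode_latents, assuming the usual vae_scale_factor of 8: a latent
+ # batch of shape [B, 4, 32, 32] decodes to a float32 numpy array of shape (B, 256, 256, 3)
+ # with values in [0, 1]:
+ #
+ #   imgs = self.decode_latents(latents)   # (4, 256, 256, 3)
+ #   pil_imgs = self.numpy_to_pil(imgs)    # list of 4 PIL images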
392
+ def prepare_extra_step_kwargs(self, generator, eta):
393
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
394
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
395
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
396
+ # and should be between [0, 1]
397
+
398
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
399
+ extra_step_kwargs = {}
400
+ if accepts_eta:
401
+ extra_step_kwargs["eta"] = eta
402
+
403
+ # check if the scheduler accepts generator
404
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
405
+ if accepts_generator:
406
+ extra_step_kwargs["generator"] = generator
407
+ return extra_step_kwargs
408
+
409
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
410
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
411
+ if isinstance(generator, list) and len(generator) != batch_size:
412
+ raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
413
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators.")
414
+
415
+ if latents is None:
416
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
417
+ else:
418
+ latents = latents.to(device)
419
+
420
+ # scale the initial noise by the standard deviation required by the scheduler
421
+ latents = latents * self.scheduler.init_noise_sigma
422
+ return latents
423
+
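+ # A minimal sketch of prepare_latents for the defaults used in __call__ below
+ # (height = width = 256, 4 latent channels, an assumed vae_scale_factor of 8):
+ #
+ #   latents = self.prepare_latents(4, 4, 256, 256, torch.float32, device, generator)
+ #   latents.shape  # torch.Size([4, 4, 32, 32]), already scaled by init_noise_sigma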
424
+ @torch.no_grad()
425
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
426
+ def __call__(
427
+ self,
428
+ height: int = 256,
429
+ width: int = 256,
430
+ num_inference_steps: int = 50,
431
+ guidance_scale: float = 7.0,
432
+ prompt: str = "a car",
433
+ negative_prompt: str = "bad quality",
434
+ num_images_per_prompt: int = 1,
435
+ eta: float = 0.0,
436
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
437
+ output_type: Optional[str] = "pil",
438
+ return_dict: bool = True,
439
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
440
+ callback_steps: int = 1,
441
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
442
+ controlnet_conditioning_scale: float = 1.0,
443
+ ):
444
+ r"""
445
+ Function invoked when calling the pipeline for generation.
446
+
447
+ Args:
448
+ prompt (`str` or `List[str]`, *optional*, defaults to `"a car"`):
449
+ The prompt or prompts to guide the multi-view image generation.
452
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
453
+ The height in pixels of the generated image.
454
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
455
+ The width in pixels of the generated image.
456
+ num_inference_steps (`int`, *optional*, defaults to 50):
457
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
458
+ expense of slower inference.
459
+ guidance_scale (`float`, *optional*, defaults to 7.0):
460
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
461
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
462
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
463
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
464
+ usually at the expense of lower image quality.
465
+ negative_prompt (`str` or `List[str]`, *optional*):
466
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
467
+ `negative_prompt_embeds` instead.
468
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
469
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
470
+ The number of images to generate per prompt.
471
+ eta (`float`, *optional*, defaults to 0.0):
472
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
473
+ [`schedulers.DDIMScheduler`], will be ignored for others.
474
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
475
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
476
+ to make generation deterministic.
477
+ latents (`torch.FloatTensor`, *optional*):
478
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
479
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
480
+ tensor will be generated by sampling using the supplied random `generator`.
481
+ prompt_embeds (`torch.FloatTensor`, *optional*):
482
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
483
+ provided, text embeddings will be generated from `prompt` input argument.
484
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
485
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
486
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
487
+ argument.
488
+ output_type (`str`, *optional*, defaults to `"pil"`):
489
+ The output format of the generated image. Choose between
490
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
491
+ return_dict (`bool`, *optional*, defaults to `True`):
492
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
493
+ plain tuple.
494
+ callback (`Callable`, *optional*):
495
+ A function that will be called every `callback_steps` steps during inference. The function will be
496
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
497
+ callback_steps (`int`, *optional*, defaults to 1):
498
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
499
+ called at every step.
500
+ cross_attention_kwargs (`dict`, *optional*):
501
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
502
+ `self.processor` in
503
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
504
+
505
+ Examples:
506
+
507
+ Returns:
508
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
509
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
510
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
511
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
512
+ (nsfw) content, according to the `safety_checker`.
513
+ """
514
+ # 0. Multi-view setup: MVDream denoises a fixed batch of 4 views on a single CUDA device
516
+ batch_size = 4
517
+ device = torch.device("cuda:0")
517
+
518
+ camera = get_camera(4).to(device=device)
519
+
520
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
521
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
522
+ # corresponds to doing no classifier free guidance.
523
+ do_classifier_free_guidance = guidance_scale > 1.0
524
+
525
+ # 4. Prepare timesteps
526
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
527
+ timesteps = self.scheduler.timesteps
528
+
529
+ prompt_embeds: torch.Tensor = self._encode_prompt(
530
+ prompt=prompt,
531
+ device=device,
532
+ num_images_per_prompt=num_images_per_prompt,
533
+ do_classifier_free_guidance=do_classifier_free_guidance,
534
+ negative_prompt=negative_prompt,
535
+ ) # type: ignore
536
+
537
+ # 5. Prepare latent variables
538
+ latents: torch.Tensor = self.prepare_latents(
539
+ batch_size * num_images_per_prompt,
540
+ 4,
541
+ height,
542
+ width,
543
+ prompt_embeds.dtype,
544
+ device,
545
+ generator,
546
+ None,
547
+ )
548
+
549
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
550
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
551
+
552
+ # 7. Denoising loop
553
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
554
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
555
+ for i, t in enumerate(timesteps):
556
+ # expand the latents if we are doing classifier free guidance
557
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
558
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
559
+
560
+ # predict the noise residual
561
+ # repeat the (negative, positive) prompt embeddings once per view; keep the result in a
+ # separate variable so `prompt_embeds` is not re-concatenated (and does not grow) on
+ # every denoising step
+ multiview_prompt_embeds = torch.cat([prompt_embeds] * 4)
562
+ print(f"shape of latent_model_input: {latent_model_input.shape}")  # [2*4, 4, 32, 32]
563
+ print(f"shape of prompt_embeds: {multiview_prompt_embeds.shape}")  # [2*4, 77, 768]
564
+ print(f"shape of camera: {camera.shape}")  # [4, 16]
565
+ noise_pred = self.unet.forward(x=latent_model_input, timesteps=torch.tensor([t], device=device), context=multiview_prompt_embeds, num_frames=4)
566
+
567
+ # perform guidance
568
+ if do_classifier_free_guidance:
569
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
570
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
571
+
572
+ # compute the previous noisy sample x_t -> x_t-1
573
+ # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
574
+ latents: torch.Tensor = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
575
+
576
+ # call the callback, if provided
577
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
578
+ progress_bar.update()
579
+ if callback is not None and i % callback_steps == 0:
580
+ callback(i, t, latents) # type: ignore
581
+
582
+ # 8. Post-processing
583
+ if output_type == "latent":
584
+ image = latents
585
+ elif output_type == "pil":
586
+ # 8. Post-processing
587
+ image = self.decode_latents(latents)
588
+ # 10. Convert to PIL
589
+ image = self.numpy_to_pil(image)
590
+ else:
591
+ # 8. Post-processing
592
+ image = self.decode_latents(latents)
593
+
594
+ # Offload last model to CPU
595
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
596
+ self.final_offload_hook.offload()
597
+
598
+ if not return_dict:
599
+ return (image, None)
600
+
601
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
scripts/util.py ADDED
@@ -0,0 +1,282 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+ import importlib
18
+
19
+ def instantiate_from_config(config):
20
+ if not "target" in config:
21
+ if config == '__is_first_stage__':
22
+ return None
23
+ elif config == "__is_unconditional__":
24
+ return None
25
+ raise KeyError("Expected key `target` to instantiate.")
26
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
27
+
28
+
29
+ def get_obj_from_str(string, reload=False):
30
+ module, cls = string.rsplit(".", 1)
31
+ if reload:
32
+ module_imp = importlib.import_module(module)
33
+ importlib.reload(module_imp)
34
+ return getattr(importlib.import_module(module, package=None), cls)
35
+
36
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
37
+ if schedule == "linear":
38
+ betas = (
39
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
40
+ )
41
+
42
+ elif schedule == "cosine":
43
+ timesteps = (
44
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
45
+ )
46
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
47
+ alphas = torch.cos(alphas).pow(2)
48
+ alphas = alphas / alphas[0]
49
+ betas = 1 - alphas[1:] / alphas[:-1]
50
+ betas = np.clip(betas, a_min=0, a_max=0.999)
51
+
52
+ elif schedule == "sqrt_linear":
53
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
54
+ elif schedule == "sqrt":
55
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
56
+ else:
57
+ raise ValueError(f"schedule '{schedule}' unknown.")
58
+ return betas.numpy()
59
+
60
+
61
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
62
+ if ddim_discr_method == 'uniform':
63
+ c = num_ddpm_timesteps // num_ddim_timesteps
64
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
65
+ elif ddim_discr_method == 'quad':
66
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
67
+ else:
68
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
69
+
70
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
71
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
72
+ steps_out = ddim_timesteps + 1
73
+ if verbose:
74
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
75
+ return steps_out
76
+
77
+
78
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
79
+ # select alphas for computing the variance schedule
80
+ alphas = alphacums[ddim_timesteps]
81
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
82
+
83
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
84
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
85
+ if verbose:
86
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
87
+ print(f'For the chosen value of eta, which is {eta}, '
88
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
89
+ return sigmas, alphas, alphas_prev
90
+
91
+
92
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
93
+ """
94
+ Create a beta schedule that discretizes the given alpha_t_bar function,
95
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
96
+ :param num_diffusion_timesteps: the number of betas to produce.
97
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
98
+ produces the cumulative product of (1-beta) up to that
99
+ part of the diffusion process.
100
+ :param max_beta: the maximum beta to use; use values lower than 1 to
101
+ prevent singularities.
102
+ """
103
+ betas = []
104
+ for i in range(num_diffusion_timesteps):
105
+ t1 = i / num_diffusion_timesteps
106
+ t2 = (i + 1) / num_diffusion_timesteps
107
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
108
+ return np.array(betas)
109
+
110
+
111
+ def extract_into_tensor(a, t, x_shape):
112
+ b, *_ = t.shape
113
+ out = a.gather(-1, t)
114
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
115
+
116
+
117
+ def checkpoint(func, inputs, params, flag):
118
+ """
119
+ Evaluate a function without caching intermediate activations, allowing for
120
+ reduced memory at the expense of extra compute in the backward pass.
121
+ :param func: the function to evaluate.
122
+ :param inputs: the argument sequence to pass to `func`.
123
+ :param params: a sequence of parameters `func` depends on but does not
124
+ explicitly take as arguments.
125
+ :param flag: if False, disable gradient checkpointing.
126
+ """
127
+ if flag:
128
+ args = tuple(inputs) + tuple(params)
129
+ return CheckpointFunction.apply(func, len(inputs), *args)
130
+ else:
131
+ return func(*inputs)
132
+
133
+ class CheckpointFunction(torch.autograd.Function):
134
+ @staticmethod
135
+ def forward(ctx, run_function, length, *args):
136
+ ctx.run_function = run_function
137
+ ctx.input_tensors = list(args[:length])
138
+ ctx.input_params = list(args[length:])
139
+
140
+ with torch.no_grad():
141
+ output_tensors = ctx.run_function(*ctx.input_tensors)
142
+ return output_tensors
143
+
144
+ @staticmethod
145
+ def backward(ctx, *output_grads):
146
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
147
+ with torch.enable_grad():
148
+ # Fixes a bug where the first op in run_function modifies the
149
+ # Tensor storage in place, which is not allowed for detach()'d
150
+ # Tensors.
151
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
152
+ output_tensors = ctx.run_function(*shallow_copies)
153
+ input_grads = torch.autograd.grad(
154
+ output_tensors,
155
+ ctx.input_tensors + ctx.input_params,
156
+ output_grads,
157
+ allow_unused=True,
158
+ )
159
+ del ctx.input_tensors
160
+ del ctx.input_params
161
+ del output_tensors
162
+ return (None, None) + input_grads
163
+
164
+
165
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
166
+ """
167
+ Create sinusoidal timestep embeddings.
168
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
169
+ These may be fractional.
170
+ :param dim: the dimension of the output.
171
+ :param max_period: controls the minimum frequency of the embeddings.
172
+ :return: an [N x dim] Tensor of positional embeddings.
173
+ """
174
+ if not repeat_only:
175
+ half = dim // 2
176
+ freqs = torch.exp(
177
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
178
+ ).to(device=timesteps.device)
179
+ args = timesteps[:, None].float() * freqs[None]
180
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
181
+ if dim % 2:
182
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
183
+ else:
184
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
185
+ # import pdb; pdb.set_trace()
186
+ return embedding
187
+
188
+
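+ # A minimal sketch of timestep_embedding: the output concatenates cosines and sines over
+ # log-spaced frequencies:
+ #
+ #   t = torch.tensor([0, 500, 999])
+ #   emb = timestep_embedding(t, dim=320)  # shape [3, 320]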
189
+ def zero_module(module):
190
+ """
191
+ Zero out the parameters of a module and return it.
192
+ """
193
+ for p in module.parameters():
194
+ p.detach().zero_()
195
+ return module
196
+
197
+
198
+ def scale_module(module, scale):
199
+ """
200
+ Scale the parameters of a module and return it.
201
+ """
202
+ for p in module.parameters():
203
+ p.detach().mul_(scale)
204
+ return module
205
+
206
+
207
+ def mean_flat(tensor):
208
+ """
209
+ Take the mean over all non-batch dimensions.
210
+ """
211
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
212
+
213
+
214
+ def normalization(channels):
215
+ """
216
+ Make a standard normalization layer.
217
+ :param channels: number of input channels.
218
+ :return: an nn.Module for normalization.
219
+ """
220
+ return GroupNorm32(32, channels)
221
+
222
+
223
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
224
+ class SiLU(nn.Module):
225
+ def forward(self, x):
226
+ return x * torch.sigmoid(x)
227
+
228
+
229
+ class GroupNorm32(nn.GroupNorm):
230
+ def forward(self, x):
231
+ return super().forward(x.float()).type(x.dtype)
232
+
233
+ def conv_nd(dims, *args, **kwargs):
234
+ """
235
+ Create a 1D, 2D, or 3D convolution module.
236
+ """
237
+ if dims == 1:
238
+ return nn.Conv1d(*args, **kwargs)
239
+ elif dims == 2:
240
+ return nn.Conv2d(*args, **kwargs)
241
+ elif dims == 3:
242
+ return nn.Conv3d(*args, **kwargs)
243
+ raise ValueError(f"unsupported dimensions: {dims}")
244
+
245
+
246
+ def linear(*args, **kwargs):
247
+ """
248
+ Create a linear module.
249
+ """
250
+ return nn.Linear(*args, **kwargs)
251
+
252
+
253
+ def avg_pool_nd(dims, *args, **kwargs):
254
+ """
255
+ Create a 1D, 2D, or 3D average pooling module.
256
+ """
257
+ if dims == 1:
258
+ return nn.AvgPool1d(*args, **kwargs)
259
+ elif dims == 2:
260
+ return nn.AvgPool2d(*args, **kwargs)
261
+ elif dims == 3:
262
+ return nn.AvgPool3d(*args, **kwargs)
263
+ raise ValueError(f"unsupported dimensions: {dims}")
264
+
265
+
266
+ class HybridConditioner(nn.Module):
267
+
268
+ def __init__(self, c_concat_config, c_crossattn_config):
269
+ super().__init__()
270
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
271
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
272
+
273
+ def forward(self, c_concat, c_crossattn):
274
+ c_concat = self.concat_conditioner(c_concat)
275
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
276
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
277
+
278
+
279
+ def noise_like(shape, device, repeat=False):
280
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
281
+ noise = lambda: torch.randn(shape, device=device)
282
+ return repeat_noise() if repeat else noise()