yosepyossi committed on
Commit
f3d2df3
·
verified ·
1 Parent(s): b97ea5a

Upload folder using huggingface_hub

Browse files
.DS_Store ADDED
Binary file (8.2 kB). View file
 
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 224
26
+ }
27
+ }
image_encoder/.DS_Store ADDED
Binary file (6.15 kB). View file
 
image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
3
+ "architectures": [
4
+ "CLIPVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.46.2"
23
+ }
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8eb46f477ef5e1859b659014aed6ca56cdc207c12cb7a0f9d61b4d80a1a7bb84
3
+ size 2523128312
model_index.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MVRAGPipeline",
3
+ "_diffusers_version": "0.25.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "image_encoder": [
9
+ "transformers",
10
+ "CLIPVisionModel"
11
+ ],
12
+ "resampler": [
13
+ "resampler",
14
+ "Resampler"
15
+ ],
16
+ "requires_safety_checker": false,
17
+ "scheduler": [
18
+ "diffusers",
19
+ "DDIMScheduler"
20
+ ],
21
+ "text_encoder": [
22
+ "transformers",
23
+ "CLIPTextModel"
24
+ ],
25
+ "tokenizer": [
26
+ "transformers",
27
+ "CLIPTokenizer"
28
+ ],
29
+ "unet": [
30
+ "mv_unet",
31
+ "MultiViewUNetModel"
32
+ ],
33
+ "vae": [
34
+ "diffusers",
35
+ "AutoencoderKL"
36
+ ]
37
+ }
resampler/config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "Resampler",
3
+ "_diffusers_version": "0.25.0",
4
+ "dim": 1024,
5
+ "depth": 8,
6
+ "dim_head": 64,
7
+ "heads": 12,
8
+ "num_queries": 16,
9
+ "embedding_dim": 1280,
10
+ "output_dim": 1024,
11
+ "ff_mult": 4
12
+ }
resampler/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:089e78f43f1f55ab598aecf5987ae1841c7b344f027bee0414e2c6df99e11c39
3
+ size 194171440
resampler/resampler.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from safetensors.torch import load_file
7
+
8
+
9
+ # FFN
10
def FeedForward(dim, mult=4):
    """Build a pre-norm two-layer MLP: LayerNorm -> Linear -> GELU -> Linear.

    Both linear layers are bias-free; `mult` sets the hidden-width expansion.
    """
    hidden = int(dim * mult)
    blocks = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*blocks)
18
+
19
+
20
+ def reshape_tensor(x, heads):
21
+ bs, length, width = x.shape
22
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
23
+ x = x.view(bs, length, heads, -1)
24
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
25
+ x = x.transpose(1, 2)
26
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
27
+ x = x.reshape(bs, heads, length, -1)
28
+ return x
29
+
30
+
31
class PerceiverAttention(nn.Module):
    """Perceiver-style attention: a small set of latent queries attends to
    the input features concatenated with the latents themselves.

    All projections are bias-free; keys/values come from a single fused
    ``to_kv`` projection.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D).
            latents (torch.Tensor): latent features, shape (b, n2, D).

        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        bsz, n_latents, _ = latents.shape

        def split_heads(t):
            # (b, n, h*d) -> (b, h, n, d)
            return t.view(bsz, t.shape[1], self.heads, -1).transpose(1, 2)

        queries = split_heads(self.to_q(latents))
        # Keys/values see both the inputs and the latents.
        keys, values = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)
        keys = split_heads(keys)
        values = split_heads(values)

        # Scale q and k each by dim_head**-0.25: more stable in fp16 than
        # dividing the logits afterwards.
        s = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * s) @ (keys * s).transpose(-2, -1)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = attn @ values

        out = out.permute(0, 2, 1, 3).reshape(bsz, n_latents, -1)
        return self.to_out(out)
77
+
78
+
79
class Resampler(nn.Module):
    """Perceiver resampler: compresses a variable-length sequence of image
    embeddings into a fixed set of ``num_queries`` output tokens.

    A bank of learned latent queries is refined by ``depth`` blocks of
    (PerceiverAttention, FeedForward), each with a residual connection.
    """

    def __init__(
        self,
        dim=1024,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=16,
        embedding_dim=1280,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()
        # Learned latent queries, scaled for unit-ish variance at init.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            block = nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            self.layers.append(block)

    def forward(self, x):
        """Map (b, n, embedding_dim) features to (b, num_queries, output_dim)."""
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = latents + attn(x, latents)
            latents = latents + ff(latents)

        return self.norm_out(self.proj_out(latents))

    @classmethod
    def from_pretrained(cls, pretrained_model_path, torch_dtype=None, **kwargs):
        """Build a Resampler and load weights from
        ``<pretrained_model_path>/model.safetensors``.

        Only recognized constructor options are forwarded from ``kwargs``;
        anything else is silently ignored.
        """
        allowed = {
            "dim", "depth", "dim_head", "heads", "num_queries",
            "embedding_dim", "output_dim", "ff_mult",
        }
        model = cls(**{k: v for k, v in kwargs.items() if k in allowed})
        state_dict = load_file(f"{pretrained_model_path}/model.safetensors")
        model.load_state_dict(state_dict)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model

    @property
    def dtype(self):
        # Dtype of the first parameter; float32 if the module has none.
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            return torch.float32
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.25.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
text_encoder/.DS_Store ADDED
Binary file (6.15 kB). View file
 
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.35.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
3
+ size 680820392
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/.DS_Store ADDED
Binary file (6.15 kB). View file
 
unet/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiViewUNetModel",
3
+ "_diffusers_version": "0.25.0",
4
+ "attention_resolutions": [
5
+ 4,
6
+ 2,
7
+ 1
8
+ ],
9
+ "camera_dim": 16,
10
+ "channel_mult": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 4
15
+ ],
16
+ "context_dim": 1024,
17
+ "image_size": 32,
18
+ "in_channels": 4,
19
+ "model_channels": 320,
20
+ "num_head_channels": 64,
21
+ "num_res_blocks": 2,
22
+ "out_channels": 4,
23
+ "transformer_depth": 1
24
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0a6754790780176351e6fd0d5f082a1e14469740419ffa346869db3c0705a25
3
+ size 3226918912
unet/mv_unet.py ADDED
@@ -0,0 +1,827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from inspect import isfunction
3
+ from typing import Optional, Any, List
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+ from diffusers.configuration_utils import ConfigMixin
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+
13
+ # require xformers!
14
+ import xformers
15
+ import xformers.ops
16
+
17
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element
        (may be fractional).
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, just tile the raw timestep across `dim`.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        # `repeat` (einops) must be available at module level here.
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd width: pad with a zero column so the result is exactly `dim` wide.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
43
+
44
+
45
def zero_module(module):
    """Zero out all parameters of `module` in place and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
52
+
53
+
54
def conv_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D convolution module.

    Extra positional/keyword args are forwarded to the `nn.ConvNd` constructor.
    """
    conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(dims)
    if conv_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(*args, **kwargs)
65
+
66
+
67
def avg_pool_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D average pooling module.

    Extra positional/keyword args are forwarded to the `nn.AvgPoolNd` constructor.
    """
    pool_cls = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}.get(dims)
    if pool_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)
78
+
79
+
80
def default(val, d):
    """Return `val` if it is not None; otherwise the fallback `d`.

    If the fallback is a plain function (per `inspect.isfunction`), it is
    called to produce the value lazily.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
84
+
85
+
86
class GEGLU(nn.Module):
    """Gated GELU activation: project to 2x width, then gate one half with
    GELU of the other half."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
94
+
95
+
96
class FeedForward(nn.Module):
    """Transformer MLP block with optional GEGLU gating.

    :param dim: input width.
    :param dim_out: output width; defaults to `dim`.
    :param mult: hidden-width expansion factor.
    :param glu: use GEGLU instead of Linear+GELU for the input projection.
    :param dropout: dropout probability applied after the projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
113
+
114
+
115
class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """Multi-head cross-attention backed by xformers' memory-efficient kernel.

    When ``ip=True``, an extra key/value projection pair (``to_k_ip`` /
    ``to_v_ip``) attends to image-prompt tokens, and the two attention
    outputs are blended in ``forward``.

    :param query_dim: width of the query input ``x``.
    :param context_dim: width of the context; defaults to ``query_dim``
        (self-attention) via ``default``.
    :param heads: number of attention heads.
    :param dim_head: per-head width.
    :param dropout: dropout applied after the output projection.
    :param ip: enable the image-prompt (IP) key/value branch.
    """

    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        ip=False,
    ):
        super().__init__()

        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.ip = ip
        if self.ip:
            # Separate K/V projections for image-prompt tokens only.
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        # Optional explicit xformers attention operator; None lets xformers choose.
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None):
        """Attend ``x`` over ``context``.

        :param x: query tensor of shape (b, n, query_dim).
        :param context: if given, a dict with keys 'images_tokens', 'scale'
            and 'prompt' — presumably image-prompt tokens, a blend weight,
            and text-prompt tokens (TODO confirm against the caller). If
            None, self-attention over ``x``.

        NOTE(review): passing a context dict to an instance built with
        ``ip=False`` would hit the missing ``to_k_ip`` attribute; callers
        appear expected to pair ``ip=True`` with dict contexts.
        """
        context_ip = None
        q = self.to_q(x)
        if context is not None:
            context_ip = context['images_tokens']
            # Blend weight for the IP branch; falls back to 1.0 when
            # context['scale'] is None.
            scale = default(context['scale'], 1.0)
            context = context['prompt']
        # No (text) context -> self-attention over x.
        context = default(context, x)

        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # (b, n, h*d) -> (b, n, h, d), as expected by xformers.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        if context_ip is not None:
            # Second attention pass over the image-prompt tokens.
            k_ip = self.to_k_ip(context_ip)
            v_ip = self.to_v_ip(context_ip)
            k_ip, v_ip = map(
                lambda t: t.unsqueeze(3)
                .reshape(b, t.shape[1], self.heads, self.dim_head)
                .contiguous(),
                (k_ip, v_ip),
            )
            out_ip = xformers.ops.memory_efficient_attention(
                q, k_ip, v_ip, attn_bias=None, op=self.attention_op
            )
            # Blend text- and image-conditioned outputs. NOTE(review): the
            # weights sum to 1.5, not 1.0, when scale != 0.5 — confirm this
            # is intentional.
            out = scale * out + (1.5 - scale) * out_ip

        # (b, n, h, d) -> (b, n, h*d)
        out = out.reshape(b, out.shape[1], self.heads * self.dim_head)
        return self.to_out(out)
187
+
188
+
189
class BasicTransformerBlock3D(nn.Module):
    """Pre-norm transformer block for multi-view (frame-batched) features.

    Self-attention is computed jointly across all ``num_frames`` views of a
    sample (tokens from every frame are concatenated), while cross-attention
    and the feed-forward run per frame.

    :param dim: token width.
    :param n_heads: attention heads.
    :param d_head: per-head width.
    :param context_dim: width of the cross-attention context.
    :param dropout: dropout probability.
    :param gated_ff: use GEGLU gating in the feed-forward.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        context_dim,
        dropout=0.0,
        gated_ff=True,
    ):
        super().__init__()

        # Self-attention (context_dim=None makes K/V come from x itself).
        self.attn1 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=None,  # self-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # Cross-attention over the conditioning context.
        self.attn2 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            # ip only applies to cross-attention
            ip=True,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None, num_frames=1):
        """
        :param x: tokens of shape ((b*f), l, c) with f == num_frames.
        :param context: cross-attention context forwarded to attn2.
        :param num_frames: number of views packed into the batch dimension.
        """
        # Merge the frame axis into the sequence so self-attention spans
        # all views of the same sample.
        x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
        x = self.attn1(self.norm1(x), context=None) + x
        # Back to per-frame layout for cross-attention and the MLP.
        x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
230
+
231
+
232
class SpatialTransformer3D(nn.Module):
    """Wrap a stack of BasicTransformerBlock3D blocks around a 2D feature map.

    The (b, c, h, w) input is normalized, flattened to tokens, projected to
    the transformer width, transformed, projected back, and added residually
    to the input.

    :param in_channels: channels of the incoming feature map.
    :param n_heads: attention heads per block.
    :param d_head: per-head width (transformer width = n_heads * d_head).
    :param context_dim: cross-attention input dim (scalar or one per block).
    :param depth: number of transformer blocks.
    :param dropout: dropout probability.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        context_dim,  # cross attention input dim
        depth=1,
        dropout=0.0,
    ):
        super().__init__()

        # Allow a single context_dim to serve every block.
        if not isinstance(context_dim, list):
            context_dim = [context_dim]

        self.in_channels = in_channels

        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock3D(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim[d],
                    dropout=dropout,
                )
                for d in range(depth)
            ]
        )

        # Project back from the transformer width to the channel count, and
        # zero-init so the block starts as an identity (residual below).
        # FIX: the in/out features were swapped — nn.Linear(in_channels,
        # inner_dim) only worked because inner_dim == in_channels in the
        # shipped configs; the weight shape is unchanged in that case, so
        # existing checkpoints still load.
        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

    def forward(self, x, context=None, num_frames=1):
        """
        :param x: feature map of shape (b, c, h, w).
        :param context: per-block context list (a single value is broadcast);
            if no context is given, cross-attention defaults to self-attention.
        :param num_frames: number of views packed into the batch dimension.
        :return: tensor of shape (b, c, h, w).
        """
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # (b, c, h, w) -> (b, h*w, c) token layout.
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i], num_frames=num_frames)
        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()

        # Residual connection around the whole transformer stack.
        return x + x_in
285
+
286
+
287
class CondSequential(nn.Sequential):
    """A Sequential that forwards conditioning to children that accept it.

    ResBlock children receive the timestep embedding; SpatialTransformer3D
    children receive the context and frame count; everything else is called
    with the features alone.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        for module in self:
            if isinstance(module, ResBlock):
                x = module(x, emb)
            elif isinstance(module, SpatialTransformer3D):
                x = module(x, context, num_frames=num_frames)
            else:
                x = module(x)
        return x
301
+
302
+
303
class Upsample(nn.Module):
    """An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional conv; defaults to
        `channels`.
    :param padding: padding for the optional conv.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D: keep the depth dimension, double only the spatial dims.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
334
+
335
+
336
class Downsample(nn.Module):
    """A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied;
        otherwise average pooling is used (and channels must be unchanged).
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional conv; defaults to
        `channels`.
    :param padding: padding for the optional conv.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals are only downsampled spatially, never along depth.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
368
+
369
+
370
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: use FiLM-style conditioning — the embedding
        produces a (scale, shift) pair applied around the output GroupNorm.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm

        # GroupNorm -> SiLU -> 3x3 conv (may change channel count).
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Resampling applied both to the hidden path (h_upd) and the skip
        # path (x_upd) so the residual add stays shape-consistent.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Timestep-embedding projection; doubled width when producing a
        # (scale, shift) pair for scale-shift norm.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Output stack; final conv is zero-initialized so the block starts
        # as (approximately) an identity mapping.
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # Skip connection: identity when channels match, else a conv
        # (3x3 if use_conv, otherwise 1x1) to match channel counts.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """Apply the block to `x`, conditioned on timestep embedding `emb`.

        :param x: feature tensor of shape (b, channels, *spatial).
        :param emb: embedding tensor of shape (b, emb_channels).
        :return: tensor of shape (b, out_channels, *spatial').
        """
        if self.updown:
            # Resample between the norm/activation and the conv so both the
            # hidden and skip paths are resampled identically.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM: norm(h) * (1 + scale) + shift, then the rest of the stack.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            # Additive conditioning before the output stack.
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
468
+
469
+
470
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
471
+ """
472
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
473
+ :param in_channels: channels in the input Tensor.
474
+ :param model_channels: base channel count for the model.
475
+ :param out_channels: channels in the output Tensor.
476
+ :param num_res_blocks: number of residual blocks per downsample.
477
+ :param attention_resolutions: a collection of downsample rates at which
478
+ attention will take place. May be a set, list, or tuple.
479
+ For example, if this contains 4, then at 4x downsampling, attention
480
+ will be used.
481
+ :param dropout: the dropout probability.
482
+ :param channel_mult: channel multiplier for each level of the UNet.
483
+ :param conv_resample: if True, use learned convolutions for upsampling and
484
+ downsampling.
485
+ :param dims: determines if the signal is 1D, 2D, or 3D.
486
+ :param num_classes: if specified (as an int), then this model will be
487
+ class-conditional with `num_classes` classes.
488
+ :param num_heads: the number of attention heads in each attention layer.
489
+ :param num_heads_channels: if specified, ignore num_heads and instead use
490
+ a fixed channel width per attention head.
491
+ :param num_heads_upsample: works with num_heads to set a different number
492
+ of heads for upsampling. Deprecated.
493
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
494
+ :param resblock_updown: use residual blocks for up/downsampling.
495
+ :param use_new_attention_order: use a different attention pattern for potentially
496
+ increased efficiency.
497
+ :param camera_dim: dimensionality of camera input.
498
+ """
499
+
500
+ def __init__(
501
+ self,
502
+ image_size,
503
+ in_channels,
504
+ model_channels,
505
+ out_channels,
506
+ num_res_blocks,
507
+ attention_resolutions,
508
+ dropout=0,
509
+ channel_mult=(1, 2, 4, 8),
510
+ conv_resample=True,
511
+ dims=2,
512
+ num_classes=None,
513
+ num_heads=-1,
514
+ num_head_channels=-1,
515
+ num_heads_upsample=-1,
516
+ use_scale_shift_norm=False,
517
+ resblock_updown=False,
518
+ transformer_depth=1,
519
+ context_dim=None,
520
+ n_embed=None,
521
+ num_attention_blocks=None,
522
+ adm_in_channels=None,
523
+ camera_dim=None,
524
+ ip_dim=0,
525
+ ip_weight=1.0,
526
+ **kwargs,
527
+ ):
528
+ super().__init__()
529
+ assert context_dim is not None
530
+
531
+ if num_heads_upsample == -1:
532
+ num_heads_upsample = num_heads
533
+
534
+ if num_heads == -1:
535
+ assert (
536
+ num_head_channels != -1
537
+ ), "Either num_heads or num_head_channels has to be set"
538
+
539
+ if num_head_channels == -1:
540
+ assert (
541
+ num_heads != -1
542
+ ), "Either num_heads or num_head_channels has to be set"
543
+
544
+ self.image_size = image_size
545
+ self.in_channels = in_channels
546
+ self.model_channels = model_channels
547
+ self.out_channels = out_channels
548
+ if isinstance(num_res_blocks, int):
549
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
550
+ else:
551
+ if len(num_res_blocks) != len(channel_mult):
552
+ raise ValueError(
553
+ "provide num_res_blocks either as an int (globally constant) or "
554
+ "as a list/tuple (per-level) with the same length as channel_mult"
555
+ )
556
+ self.num_res_blocks = num_res_blocks
557
+
558
+ if num_attention_blocks is not None:
559
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
560
+ assert all(
561
+ map(
562
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
563
+ range(len(num_attention_blocks)),
564
+ )
565
+ )
566
+ print(
567
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
568
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
569
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
570
+ f"attention will still not be set."
571
+ )
572
+
573
+ self.attention_resolutions = attention_resolutions
574
+ self.dropout = dropout
575
+ self.channel_mult = channel_mult
576
+ self.conv_resample = conv_resample
577
+ self.num_classes = num_classes
578
+ self.num_heads = num_heads
579
+ self.num_head_channels = num_head_channels
580
+ self.num_heads_upsample = num_heads_upsample
581
+ self.predict_codebook_ids = n_embed is not None
582
+
583
+ self.ip_dim = ip_dim
584
+ self.ip_weight = ip_weight
585
+
586
+ time_embed_dim = model_channels * 4
587
+ self.time_embed = nn.Sequential(
588
+ nn.Linear(model_channels, time_embed_dim),
589
+ nn.SiLU(),
590
+ nn.Linear(time_embed_dim, time_embed_dim),
591
+ )
592
+
593
+ if camera_dim is not None:
594
+ time_embed_dim = model_channels * 4
595
+ self.camera_embed = nn.Sequential(
596
+ nn.Linear(camera_dim, time_embed_dim),
597
+ nn.SiLU(),
598
+ nn.Linear(time_embed_dim, time_embed_dim),
599
+ )
600
+
601
+ if self.num_classes is not None:
602
+ if isinstance(self.num_classes, int):
603
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
604
+ elif self.num_classes == "continuous":
605
+ # print("setting up linear c_adm embedding layer")
606
+ self.label_emb = nn.Linear(1, time_embed_dim)
607
+ elif self.num_classes == "sequential":
608
+ assert adm_in_channels is not None
609
+ self.label_emb = nn.Sequential(
610
+ nn.Sequential(
611
+ nn.Linear(adm_in_channels, time_embed_dim),
612
+ nn.SiLU(),
613
+ nn.Linear(time_embed_dim, time_embed_dim),
614
+ )
615
+ )
616
+ else:
617
+ raise ValueError()
618
+
619
+ self.input_blocks = nn.ModuleList(
620
+ [
621
+ CondSequential(
622
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
623
+ )
624
+ ]
625
+ )
626
+ self._feature_size = model_channels
627
+ input_block_chans = [model_channels]
628
+ ch = model_channels
629
+ ds = 1
630
+ for level, mult in enumerate(channel_mult):
631
+ for nr in range(self.num_res_blocks[level]):
632
+ layers: List[Any] = [
633
+ ResBlock(
634
+ ch,
635
+ time_embed_dim,
636
+ dropout,
637
+ out_channels=mult * model_channels,
638
+ dims=dims,
639
+ use_scale_shift_norm=use_scale_shift_norm,
640
+ )
641
+ ]
642
+ ch = mult * model_channels
643
+ if ds in attention_resolutions:
644
+ if num_head_channels == -1:
645
+ dim_head = ch // num_heads
646
+ else:
647
+ num_heads = ch // num_head_channels
648
+ dim_head = num_head_channels
649
+
650
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
651
+ layers.append(
652
+ SpatialTransformer3D(
653
+ ch,
654
+ num_heads,
655
+ dim_head,
656
+ context_dim=context_dim,
657
+ depth=transformer_depth,
658
+ )
659
+ )
660
+ self.input_blocks.append(CondSequential(*layers))
661
+ self._feature_size += ch
662
+ input_block_chans.append(ch)
663
+ if level != len(channel_mult) - 1:
664
+ out_ch = ch
665
+ self.input_blocks.append(
666
+ CondSequential(
667
+ ResBlock(
668
+ ch,
669
+ time_embed_dim,
670
+ dropout,
671
+ out_channels=out_ch,
672
+ dims=dims,
673
+ use_scale_shift_norm=use_scale_shift_norm,
674
+ down=True,
675
+ )
676
+ if resblock_updown
677
+ else Downsample(
678
+ ch, conv_resample, dims=dims, out_channels=out_ch
679
+ )
680
+ )
681
+ )
682
+ ch = out_ch
683
+ input_block_chans.append(ch)
684
+ ds *= 2
685
+ self._feature_size += ch
686
+
687
+ if num_head_channels == -1:
688
+ dim_head = ch // num_heads
689
+ else:
690
+ num_heads = ch // num_head_channels
691
+ dim_head = num_head_channels
692
+
693
+ self.middle_block = CondSequential(
694
+ ResBlock(
695
+ ch,
696
+ time_embed_dim,
697
+ dropout,
698
+ dims=dims,
699
+ use_scale_shift_norm=use_scale_shift_norm,
700
+ ),
701
+ SpatialTransformer3D(
702
+ ch,
703
+ num_heads,
704
+ dim_head,
705
+ context_dim=context_dim,
706
+ depth=transformer_depth,
707
+ ),
708
+ ResBlock(
709
+ ch,
710
+ time_embed_dim,
711
+ dropout,
712
+ dims=dims,
713
+ use_scale_shift_norm=use_scale_shift_norm,
714
+ ),
715
+ )
716
+ self._feature_size += ch
717
+
718
+ self.output_blocks = nn.ModuleList([])
719
+ for level, mult in list(enumerate(channel_mult))[::-1]:
720
+ for i in range(self.num_res_blocks[level] + 1):
721
+ ich = input_block_chans.pop()
722
+ layers = [
723
+ ResBlock(
724
+ ch + ich,
725
+ time_embed_dim,
726
+ dropout,
727
+ out_channels=model_channels * mult,
728
+ dims=dims,
729
+ use_scale_shift_norm=use_scale_shift_norm,
730
+ )
731
+ ]
732
+ ch = model_channels * mult
733
+ if ds in attention_resolutions:
734
+ if num_head_channels == -1:
735
+ dim_head = ch // num_heads
736
+ else:
737
+ num_heads = ch // num_head_channels
738
+ dim_head = num_head_channels
739
+
740
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
741
+ layers.append(
742
+ SpatialTransformer3D(
743
+ ch,
744
+ num_heads,
745
+ dim_head,
746
+ context_dim=context_dim,
747
+ depth=transformer_depth,
748
+ )
749
+ )
750
+ if level and i == self.num_res_blocks[level]:
751
+ out_ch = ch
752
+ layers.append(
753
+ ResBlock(
754
+ ch,
755
+ time_embed_dim,
756
+ dropout,
757
+ out_channels=out_ch,
758
+ dims=dims,
759
+ use_scale_shift_norm=use_scale_shift_norm,
760
+ up=True,
761
+ )
762
+ if resblock_updown
763
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
764
+ )
765
+ ds //= 2
766
+ self.output_blocks.append(CondSequential(*layers))
767
+ self._feature_size += ch
768
+
769
+ self.out = nn.Sequential(
770
+ nn.GroupNorm(32, ch),
771
+ nn.SiLU(),
772
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
773
+ )
774
+ if self.predict_codebook_ids:
775
+ self.id_predictor = nn.Sequential(
776
+ nn.GroupNorm(32, ch),
777
+ conv_nd(dims, model_channels, n_embed, 1),
778
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
779
+ )
780
+
781
+ def forward(
782
+ self,
783
+ x,
784
+ timesteps=None,
785
+ context=None,
786
+ camera=None,
787
+ num_frames=1,
788
+ images_tokens=None,
789
+ scale=1.0,
790
+ **kwargs,
791
+ ):
792
+ """
793
+ Apply the model to an input batch.
794
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
795
+ :param timesteps: a 1-D batch of timesteps.
796
+ :param context: conditioning plugged in via crossattn
797
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
798
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
799
+ """
800
+ assert (
801
+ x.shape[0] % num_frames == 0
802
+ ), "input batch size must be dividable by num_frames!"
803
+
804
+ hs = []
805
+
806
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
807
+
808
+ emb = self.time_embed(t_emb)
809
+
810
+ if camera is not None:
811
+ emb = emb + self.camera_embed(camera)
812
+
813
+ context = {'prompt': context, 'images_tokens': images_tokens, 'scale': scale}
814
+ h = x
815
+ for module in self.input_blocks:
816
+ h = module(h, emb, context, num_frames=num_frames)
817
+ hs.append(h)
818
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
819
+ for module in self.output_blocks:
820
+ h = torch.cat([h, hs.pop()], dim=1)
821
+ h = module(h, emb, context, num_frames=num_frames)
822
+
823
+ h = h.type(x.dtype)
824
+ if self.predict_codebook_ids:
825
+ return self.id_predictor(h)
826
+ else:
827
+ return self.out(h)
vae/.DS_Store ADDED
Binary file (6.15 kB). View file
 
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.25.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
3
+ size 167335342