YUGOROU committed on
Commit
d70b00d
·
verified ·
1 Parent(s): 91020fc

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +12 -0
  2. .gitignore +14 -0
  3. 2D_Stage/configs/infer.yaml +24 -0
  4. 2D_Stage/material/examples/1.png +3 -0
  5. 2D_Stage/material/examples/2.png +3 -0
  6. 2D_Stage/material/examples/3.png +3 -0
  7. 2D_Stage/material/examples/4.png +3 -0
  8. 2D_Stage/material/examples/5.png +3 -0
  9. 2D_Stage/material/examples/6.png +0 -0
  10. 2D_Stage/material/examples/7.png +3 -0
  11. 2D_Stage/material/examples/8.png +3 -0
  12. 2D_Stage/material/pose.json +38 -0
  13. 2D_Stage/material/pose0.png +0 -0
  14. 2D_Stage/material/pose1.png +0 -0
  15. 2D_Stage/material/pose2.png +0 -0
  16. 2D_Stage/material/pose3.png +0 -0
  17. 2D_Stage/tuneavideo/models/PoseGuider.py +59 -0
  18. 2D_Stage/tuneavideo/models/attention.py +344 -0
  19. 2D_Stage/tuneavideo/models/imageproj.py +118 -0
  20. 2D_Stage/tuneavideo/models/refunet.py +125 -0
  21. 2D_Stage/tuneavideo/models/resnet.py +210 -0
  22. 2D_Stage/tuneavideo/models/transformer_mv2d.py +1010 -0
  23. 2D_Stage/tuneavideo/models/unet.py +497 -0
  24. 2D_Stage/tuneavideo/models/unet_blocks.py +596 -0
  25. 2D_Stage/tuneavideo/models/unet_mv2d_blocks.py +926 -0
  26. 2D_Stage/tuneavideo/models/unet_mv2d_condition.py +1509 -0
  27. 2D_Stage/tuneavideo/models/unet_mv2d_ref.py +1570 -0
  28. 2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py +585 -0
  29. 2D_Stage/tuneavideo/util.py +128 -0
  30. 2D_Stage/webui.py +323 -0
  31. 3D_Stage/configs/infer.yaml +104 -0
  32. 3D_Stage/load/tets/generate_tets.py +58 -0
  33. 3D_Stage/lrm/__init__.py +29 -0
  34. 3D_Stage/lrm/models/__init__.py +0 -0
  35. 3D_Stage/lrm/models/background/__init__.py +0 -0
  36. 3D_Stage/lrm/models/background/base.py +24 -0
  37. 3D_Stage/lrm/models/background/solid_color_background.py +58 -0
  38. 3D_Stage/lrm/models/camera.py +33 -0
  39. 3D_Stage/lrm/models/exporters/__init__.py +0 -0
  40. 3D_Stage/lrm/models/exporters/base.py +33 -0
  41. 3D_Stage/lrm/models/exporters/mesh_exporter.py +263 -0
  42. 3D_Stage/lrm/models/isosurface.py +272 -0
  43. 3D_Stage/lrm/models/lpips.py +20 -0
  44. 3D_Stage/lrm/models/materials/__init__.py +0 -0
  45. 3D_Stage/lrm/models/materials/base.py +29 -0
  46. 3D_Stage/lrm/models/materials/no_material.py +60 -0
  47. 3D_Stage/lrm/models/mesh.py +471 -0
  48. 3D_Stage/lrm/models/networks.py +390 -0
  49. 3D_Stage/lrm/models/renderers/__init__.py +0 -0
  50. 3D_Stage/lrm/models/renderers/base.py +68 -0
.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 2D_Stage/material/examples/1.png filter=lfs diff=lfs merge=lfs -text
37
+ 2D_Stage/material/examples/2.png filter=lfs diff=lfs merge=lfs -text
38
+ 2D_Stage/material/examples/3.png filter=lfs diff=lfs merge=lfs -text
39
+ 2D_Stage/material/examples/4.png filter=lfs diff=lfs merge=lfs -text
40
+ 2D_Stage/material/examples/5.png filter=lfs diff=lfs merge=lfs -text
41
+ 2D_Stage/material/examples/7.png filter=lfs diff=lfs merge=lfs -text
42
+ 2D_Stage/material/examples/8.png filter=lfs diff=lfs merge=lfs -text
43
+ 3D_Stage/material/examples/1/1.png filter=lfs diff=lfs merge=lfs -text
44
+ 3D_Stage/material/examples/1/2.png filter=lfs diff=lfs merge=lfs -text
45
+ 3D_Stage/material/examples/1/3.png filter=lfs diff=lfs merge=lfs -text
46
+ 3D_Stage/material/examples/1/4.png filter=lfs diff=lfs merge=lfs -text
47
+ final_texture.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore Python bytecode files
2
+ *.pyc
3
+ *.pyo
4
+ __pycache__/
5
+
6
+ # Ignore virtual environment directory
7
+ venv/
8
+
9
+ /3D_Stage/outputs/
10
+ input_3D.png
11
+ input.png
12
+
13
+ # LFS pointer files (large model files downloaded at runtime)
14
+ 3D_Stage/load/tets/*.npz
2D_Stage/configs/infer.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pretrained_model_path: "sd2-community/stable-diffusion-2-1"
2
+ image_encoder_path: "./models/image_encoder"
3
+ ckpt_dir: "./models/checkpoint"
4
+
5
+ validation:
6
+ guidance_scale: 5.0
7
+ use_inv_latent: False
8
+ video_length: 4
9
+
10
+ use_pose_guider: True
11
+ use_noise: False
12
+ use_shifted_noise: False
13
+ unet_condition_type: image
14
+
15
+ unet_from_pretrained_kwargs:
16
+ camera_embedding_type: 'e_de_da_sincos'
17
+ projection_class_embeddings_input_dim: 10 # modify
18
+ joint_attention: false # modify
19
+ num_views: 4
20
+ sample_size: 96
21
+ zero_init_conv_in: false
22
+ zero_init_camera_projection: false
23
+ in_channels: 4
24
+ use_safetensors: true
2D_Stage/material/examples/1.png ADDED

Git LFS Details

  • SHA256: f8fd677efa043cc71fbe0d78e30b93c1f49fd88fd8d2a00ae6946f6f0a06d3b2
  • Pointer size: 131 Bytes
  • Size of remote file: 621 kB
2D_Stage/material/examples/2.png ADDED

Git LFS Details

  • SHA256: b342e025d7170708fdc19c1fa816835d70de61a315d6bde3e2afb390977882d0
  • Pointer size: 131 Bytes
  • Size of remote file: 305 kB
2D_Stage/material/examples/3.png ADDED

Git LFS Details

  • SHA256: cf15c6d6321c8ae8177aefd6ae1d3c41b3cbd51f1f0df13f1be53633273fa457
  • Pointer size: 131 Bytes
  • Size of remote file: 188 kB
2D_Stage/material/examples/4.png ADDED

Git LFS Details

  • SHA256: 7291130686044453c3c1a7fc7c9b66bb2424a825cf2406149d6a193cdcf7638c
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
2D_Stage/material/examples/5.png ADDED

Git LFS Details

  • SHA256: 385007de9cb47cafd6f63971afaf72e7d8ab6f7c146dc10cf0ad4788072d592c
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
2D_Stage/material/examples/6.png ADDED
2D_Stage/material/examples/7.png ADDED

Git LFS Details

  • SHA256: 1b8e002752dbae219db8c7f6ecab6b79d62c31c06a5a36520effe82eb53418a4
  • Pointer size: 131 Bytes
  • Size of remote file: 167 kB
2D_Stage/material/examples/8.png ADDED

Git LFS Details

  • SHA256: fd3def91bb231c82fb39a770a8b645078ab47d4a8f74e2068ccb8db0a0438837
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
2D_Stage/material/pose.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ [
3
+ [
4
+ 0, 0, -1, 0,
5
+ 0, 1, 0, 0,
6
+ 1, 0, 0, 0,
7
+ 1.5, 0, 0, 1
8
+ ],
9
+ "pose0.png"
10
+ ],
11
+ [
12
+ [
13
+ 0, 0, 1, 0,
14
+ 0, 1, 0, 0,
15
+ -1, 0, 0, 0,
16
+ -1.5, 0, 0, 1
17
+ ],
18
+ "pose1.png"
19
+ ],
20
+ [
21
+ [
22
+ 0, 0, 1, 0,
23
+ 0, 1, 0, 0,
24
+ -1, 0, 0, 0,
25
+ -1.5, 0, 0, 1
26
+ ],
27
+ "pose2.png"
28
+ ],
29
+ [
30
+ [
31
+ -1, 0, 0, 0,
32
+ 0, 1, 0, 0,
33
+ 0, 0, -1, 0,
34
+ 0, 0, -1.5, 1
35
+ ],
36
+ "pose3.png"
37
+ ]
38
+ ]
2D_Stage/material/pose0.png ADDED
2D_Stage/material/pose1.png ADDED
2D_Stage/material/pose2.png ADDED
2D_Stage/material/pose3.png ADDED
2D_Stage/tuneavideo/models/PoseGuider.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.init as init
5
+ from einops import rearrange
6
+
7
class PoseGuider(nn.Module):
    """Convolutional encoder that turns a 3-channel pose image into a latent-channel map.

    Three stride-2 convolutions (each halving an even spatial size) and one
    stride-1 convolution feed a 1x1 projection to ``noise_latent_channels``.
    The projection is zero-initialized, so a freshly constructed module
    outputs all zeros until it is trained or weights are loaded.
    """

    def __init__(self, noise_latent_channels=4):
        """Build the conv stack and the zero-initialized output projection.

        Args:
            noise_latent_channels: number of output channels of ``final_proj``.
        """
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Final 1x1 projection to the requested number of channels.
        self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)

        self._initialize_weights()

    def _initialize_weights(self):
        """Gaussian-initialize the conv stack and zero out the final projection."""
        for layer in self.conv_layers:
            if isinstance(layer, nn.Conv2d):
                init.normal_(layer.weight, mean=0.0, std=0.02)
                if layer.bias is not None:
                    init.zeros_(layer.bias)

        init.zeros_(self.final_proj.weight)
        if self.final_proj.bias is not None:
            init.zeros_(self.final_proj.bias)

    def forward(self, pose_image):
        """Run the conv stack followed by the 1x1 projection on ``pose_image``."""
        features = self.conv_layers(pose_image)
        return self.final_proj(features)

    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        """Create a PoseGuider and load weights from a ``torch.save``-d state dict.

        Args:
            pretrained_model_path: path to the checkpoint file.

        Returns:
            The model with the checkpoint loaded non-strictly.

        Raises:
            FileNotFoundError: if ``pretrained_model_path`` does not exist.
        """
        # Fail before claiming anything was loaded; torch.load would raise the
        # same exception type, just with a less helpful message.
        if not os.path.exists(pretrained_model_path):
            raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")

        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...")

        # Take the output width from the checkpoint when it is available so a
        # checkpoint with a different channel count still loads; default is 4.
        proj_weight = state_dict.get("final_proj.weight") if isinstance(state_dict, dict) else None
        latent_channels = proj_weight.shape[0] if proj_weight is not None else 4

        model = cls(noise_latent_channels=latent_channels)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(missing)}; \n### unexpected keys: {len(unexpected)};")

        # Count every parameter (the previous "temporal" name filter matched
        # nothing in this module and therefore always reported 0).
        total_params = sum(p.numel() for p in model.parameters())
        print(f"### PoseGuider's Parameters: {total_params / 1e6} M")

        return model
2D_Stage/tuneavideo/models/attention.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
15
+
16
+ from einops import rearrange, repeat
17
+
18
+
19
@dataclass
class Transformer3DModelOutput(BaseOutput):
    # Output container returned by Transformer3DModel.forward when return_dict is True.
    # `sample` holds the transformed hidden states rearranged back to "b c f h w".
    sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
class Transformer3DModel(ModelMixin, ConfigMixin):
    """Transformer applied frame-wise to a 5-D ``b c f h w`` feature map.

    Frames are folded into the batch axis, normalized, projected to
    ``num_attention_heads * attention_head_dim`` channels, passed through
    ``num_layers`` BasicTransformerBlock instances, projected back to
    ``in_channels`` and added to the input as a residual.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_attn_temp: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_attn_temp=use_attn_temp,
                )
                for _ in range(num_layers)
            ]
        )

        # Output layers: map inner_dim back to in_channels. The linear variant
        # previously had its arguments swapped, which only worked when
        # in_channels == inner_dim (where both orders give identical shapes).
        if use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        """Apply the transformer to every frame of ``hidden_states``.

        Args:
            hidden_states: 5-D tensor laid out as ``b c f h w``.
            encoder_hidden_states: optional ``b n c`` conditioning; when given
                it is repeated once per frame.
            timestep: forwarded to each block (used by AdaLayerNorm variants).
            return_dict: when False a 1-tuple is returned instead of
                Transformer3DModelOutput.

        Returns:
            Transformer3DModelOutput (or a tuple) whose sample has the same
            ``b c f h w`` layout as the input.
        """
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        # None is the declared default, so only expand when conditioning exists.
        if encoder_hidden_states is not None:
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Flatten to tokens first; the linear projection works on the last axis.
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length,
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            # After the linear proj_out the last axis is back to the input channel count.
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)
139
+
140
+
141
class BasicTransformerBlock(nn.Module):
    """One transformer block: SparseCausalAttention, optional cross-attention,
    feed-forward, and an optional attention over the frame axis.

    Every sub-layer is applied pre-norm and added back to its input.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_attn_temp: bool = False
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.use_attn_temp = use_attn_temp

        # SC-Attn
        self.attn1 = SparseCausalAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn (and its norm) exist only when a conditioning width is given.
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.attn2 = None
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn
        if self.use_attn_temp:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            # Zero output weights: the new branch starts with no weight contribution.
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Toggle xformers attention on ``attn1`` and, if present, ``attn2``.

        Raises:
            ModuleNotFoundError: xformers is not installed.
            ValueError: CUDA is not available.
            Exception: whatever the xformers probe call raises is propagated.
        """
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        if not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )

        # Make sure we can run the memory efficient attention; any failure propagates.
        _ = xformers.ops.memory_efficient_attention(
            torch.randn((1, 2, 40), device="cuda"),
            torch.randn((1, 2, 40), device="cuda"),
            torch.randn((1, 2, 40), device="cuda"),
        )
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.attn2 is not None:
            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        # NOTE(review): attn_temp is not switched here (the original had that line commented out).

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        """Run the block on ``(b f) d c`` token sequences.

        Args:
            hidden_states: tokens with frames folded into the batch axis.
            encoder_hidden_states: conditioning for cross-attention.
            timestep: passed to the norms when AdaLayerNorm is in use.
            attention_mask: forwarded to attn1 and attn2.
            video_length: number of frames folded into the batch axis.

        Returns:
            Tensor with the same layout as ``hidden_states``.
        """
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        if self.only_cross_attention:
            # attn1 always regroups keys/values by frame, so it needs
            # video_length in this path too (it was previously omitted).
            hidden_states = (
                self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states,
                    attention_mask=attention_mask,
                    video_length=video_length,
                )
                + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention: attend across frames independently at each token position.
        if self.use_attn_temp:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
275
+
276
+
277
class SparseCausalAttention(CrossAttention):
    """Attention whose keys/values are gathered across the frames of a clip.

    With ``use_full_attn=True`` (the default) every frame attends to the
    tokens of all frames; otherwise each frame attends to the first frame and
    to its predecessor (frame 0 uses itself as predecessor).
    """

    @staticmethod
    def _gather_context_frames(tensor, video_length, use_full_attn):
        """Regroup ``(b f) d c`` keys or values so each frame sees its context frames."""
        tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length)
        if use_full_attn:
            # Slot i holds frame i repeated for every query frame; concatenating
            # on the token axis gives each frame the tokens of all frames.
            context = [tensor[:, [i] * video_length] for i in range(video_length)]
        else:
            former_frame_index = torch.arange(video_length) - 1
            former_frame_index[0] = 0
            context = [tensor[:, [0] * video_length], tensor[:, former_frame_index]]
        tensor = torch.cat(context, dim=2)
        return rearrange(tensor, "b f d c -> (b f) d c")

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_full_attn=True):
        """Compute attention for ``(b f) d c`` tokens.

        Args:
            hidden_states: query tokens with frames folded into the batch axis.
            encoder_hidden_states: optional key/value source; defaults to
                ``hidden_states``.
            attention_mask: optional mask, padded and repeated per head.
            video_length: number of frames folded into the batch axis (required).
            use_full_attn: gather all frames (True) or first + previous frame (False).

        Raises:
            ValueError: ``video_length`` was not provided.
            NotImplementedError: the layer was built with ``added_kv_proj_dim``.
        """
        if video_length is None:
            # Keys/values cannot be regrouped by frame without it; fail clearly
            # instead of deep inside torch.arange / einops.
            raise ValueError("SparseCausalAttention.forward requires `video_length`.")

        _, sequence_length, _ = hidden_states.shape

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self._gather_context_frames(key, video_length, use_full_attn)
        value = self._gather_context_frames(value, video_length, use_full_attn)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
2D_Stage/tuneavideo/models/imageproj.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ # FFN
8
def FeedForward(dim, mult=4):
    """Build a bias-free MLP: LayerNorm -> Linear -> GELU -> Linear.

    The hidden width is ``int(dim * mult)``; input and output width are ``dim``.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
16
+
17
def reshape_tensor(x, heads):
    """Split the last axis of a ``(bs, length, width)`` tensor into attention heads.

    Returns a 4-D tensor of shape ``(bs, heads, length, width // heads)``; the
    batch and head axes are kept separate (they are not merged).
    """
    bs, length, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head) -> (bs, heads, length, dim_per_head)
    per_head = x.view(bs, length, heads, -1).transpose(1, 2)
    # The final reshape keeps the 4-D shape and only materializes the transposed layout.
    return per_head.reshape(bs, heads, length, -1)
26
+
27
+
28
class PerceiverAttention(nn.Module):
    """Attention in which latent queries attend over image features plus the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        projected_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, projected_dim, bias=False)
        self.to_kv = nn.Linear(dim, projected_dim * 2, bias=False)
        self.to_out = nn.Linear(projected_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: updated latents of shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        batch, num_latents, _ = latents.shape

        # Queries come from the latents; keys/values from features and latents concatenated.
        queries = reshape_tensor(self.to_q(latents), self.heads)
        keys, values = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)
        keys = reshape_tensor(keys, self.heads)
        values = reshape_tensor(values, self.heads)

        # Scale q and k separately: more stable with f16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * scale) @ (keys * scale).transpose(-2, -1)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = probs @ values

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        attended = attended.permute(0, 2, 1, 3).reshape(batch, num_latents, -1)

        return self.to_out(attended)
74
+
75
class Resampler(nn.Module):
    """Resample a sequence of embeddings into ``num_queries`` learned latent tokens.

    Input features are projected to ``dim``, refined by ``depth`` rounds of
    PerceiverAttention + FeedForward (both residual), then projected to
    ``output_dim`` and layer-normalized.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        # Learned query tokens, shared across the batch.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Return normalized latents of shape ``(x.size(0), num_queries, output_dim)``."""
        tokens = self.latents.repeat(x.size(0), 1, 1)
        features = self.proj_in(x)

        for attention, feed_forward in self.layers:
            tokens = attention(features, tokens) + tokens
            tokens = feed_forward(tokens) + tokens

        return self.norm_out(self.proj_out(tokens))
2D_Stage/tuneavideo/models/refunet.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from typing import Any, Dict, Optional
4
+ from diffusers.utils.import_utils import is_xformers_available
5
+ from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
6
class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention-processor wrapper that records or injects reference features.

    When ``enabled``, the ``mode`` argument selects the behavior:
      * ``'w'``: store this layer's encoder states in ``ref_dict`` under ``name``.
      * ``'r'``: merge the views, append (and remove) the stored reference, attend.
      * ``'m'``: append the stored reference without removing it, attend.
      * ``'n'``: merge the views and attend, without any reference.
    When disabled, the call is passed straight to ``chained_proc``.
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance = False,num_views=4,
        multiview_attention=True,
        cross_domain_attention=False,
    ) -> Any:
        """Dispatch to ``chained_proc`` according to ``mode``.

        Raises:
            ValueError: ``enabled`` is True and ``mode`` is not one of
                ``'w'``, ``'r'``, ``'m'``, ``'n'``.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        if not self.enabled:
            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)

        # Keyword arguments shared by every enabled call below.
        no_mv = dict(multiview_attention=False, cross_domain_attention=False)

        if mode == 'w':
            ref_dict[self.name] = encoder_hidden_states
            return self.chained_proc(
                attn, hidden_states, encoder_hidden_states, attention_mask, num_views=1, **no_mv
            )

        if mode == 'r':
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
            if self.name in ref_dict:
                # Append the reference tokens, then give every view the same merged sequence.
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states, ref_dict.pop(self.name)], dim=1
                ).unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
            return self.chained_proc(
                attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views, **no_mv
            )

        if mode == 'm':
            encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
            # This branch previously never produced a result (UnboundLocalError
            # on return). NOTE(review): it now uses the same default call as the
            # disabled path - confirm no view-specific kwargs are wanted here.
            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)

        if mode == 'n':
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
            encoder_hidden_states = encoder_hidden_states.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
            return self.chained_proc(
                attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views, **no_mv
            )

        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError(f"Unknown reference attention mode: {mode!r}")
53
+
54
class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that runs a reference pass on a noised conditioning latent
    before the main denoising pass.

    Every attention processor of ``unet`` is replaced by a
    ReferenceOnlyAttnProc; only processors whose name ends in
    ``"attn1.processor"`` are enabled.
    """

    def __init__(self, unet, train_sched, val_sched) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        def build_default_proc():
            # A fresh processor instance per attention layer.
            if is_xformers_available():
                return XFormersMVAttnProcessor()
            return MVAttnProcessor()

        wrapped_procs = {
            proc_name: ReferenceOnlyAttnProc(
                build_default_proc(),
                enabled=proc_name.endswith("attn1.processor"),
                name=proc_name,
            )
            for proc_name in unet.attn_processors
        }
        self.unet.set_attn_processor(wrapped_procs)

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for attributes this module does not define."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        """Run the UNet in write mode so enabled processors fill ``ref_dict``.

        Under CFG guidance the first entry of ``encoder_hidden_states`` and
        ``class_labels`` is dropped before the pass. The UNet output is discarded.
        """
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        write_kwargs = dict(mode="w", ref_dict=ref_dict)
        self.unet(
            noisy_cond_lat,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=write_kwargs,
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        """Noise the conditioning latent, record reference features, then denoise ``sample``.

        ``cross_attention_kwargs`` must contain ``'cond_lat'`` and may contain
        ``'is_cfg_guidance'`` (defaults to False). Optional residuals are cast
        to the UNet dtype before being passed on.
        """
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        noise = torch.randn_like(cond_lat)

        # Training uses the timestep as given; evaluation flattens it first.
        if self.training:
            scheduler, sched_timestep = self.train_sched, timestep
        else:
            scheduler, sched_timestep = self.val_sched, timestep.reshape(-1)
        noisy_cond_lat = scheduler.add_noise(cond_lat, noise, sched_timestep)
        noisy_cond_lat = scheduler.scale_model_input(noisy_cond_lat, sched_timestep)

        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        weight_dtype = self.unet.dtype

        down_residuals = None
        if down_block_res_samples is not None:
            down_residuals = [res.to(dtype=weight_dtype) for res in down_block_res_samples]

        mid_residual = None
        if mid_block_res_sample is not None:
            mid_residual = mid_block_res_sample.to(dtype=weight_dtype)

        read_kwargs = dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance)
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=read_kwargs,
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            **kwargs
        )
2D_Stage/tuneavideo/models/resnet.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
class InflatedConv3d(nn.Conv2d):
    """A ``Conv2d`` that is applied frame-by-frame to a 5D tensor."""

    def forward(self, x):
        """Apply the 2D convolution independently to every frame of ``x``.

        ``x`` is indexed as ``(b, c, f, h, w)``: dim 2 (frames) is folded
        into the batch dim, the inherited 2D convolution is applied, and the
        frames are then unfolded back into dim 2.
        """
        batch, channels, frames, height, width = x.shape

        # "b c f h w -> (b f) c h w"
        stacked = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        convolved = super().forward(stacked)

        # "(b f) c h w -> b c f h w"
        unstacked = convolved.reshape(batch, frames, *convolved.shape[1:])
        return unstacked.permute(0, 2, 1, 3, 4)
19
+
20
+
21
class Upsample3D(nn.Module):
    """Nearest-neighbour upsampling of a 5D tensor with an optional convolution.

    With no explicit ``output_size`` the last two dims are doubled and dim 2
    is left unchanged (``scale_factor=[1.0, 2.0, 2.0]``).

    Args:
        channels: Expected size of dim 1 of the input.
        use_conv: If true, an ``InflatedConv3d`` (kernel 3, padding 1) is
            applied after interpolation.
        use_conv_transpose: Not supported; raises ``NotImplementedError``.
        out_channels: Output channels of the convolution; defaults to
            ``channels``.
        name: ``"conv"`` stores the convolution as ``self.conv``; any other
            value stores it as ``self.Conv2d_0``.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            raise NotImplementedError
        conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) if use_conv else None

        # The attribute name determines the parameter keys in the state dict.
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        """Upsample ``hidden_states``; ``output_size`` overrides the 2x scale factor."""
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # The nearest-neighbour upsample op does not support bfloat16, so
        # interpolate in float32 and cast back afterwards.
        input_dtype = hidden_states.dtype
        is_bf16 = input_dtype == torch.bfloat16
        if is_bf16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes.
        # see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # An explicit `output_size` forces the interpolation output size and
        # disables the default scale factor.
        if output_size is None:
            resize = {"scale_factor": [1.0, 2.0, 2.0]}
        else:
            resize = {"size": output_size}
        hidden_states = F.interpolate(hidden_states, mode="nearest", **resize)

        if is_bf16:
            hidden_states = hidden_states.to(input_dtype)

        if self.use_conv:
            projection = self.conv if self.name == "conv" else self.Conv2d_0
            hidden_states = projection(hidden_states)

        return hidden_states
74
+
75
+
76
class Downsample3D(nn.Module):
    """Stride-2 convolutional downsampling of a 5D tensor.

    Only the convolutional path is implemented: constructing the layer with
    ``use_conv=False`` raises ``NotImplementedError``.

    Args:
        channels: Expected size of dim 1 of the input.
        use_conv: Must be true; selects the ``InflatedConv3d`` path.
        out_channels: Output channels of the convolution; defaults to
            ``channels``.
        padding: Padding of the convolution. ``0`` is accepted by the
            constructor but rejected in ``forward``.
        name: When ``"conv"``, the convolution is registered under both
            ``Conv2d_0`` and ``conv``; otherwise only under ``conv``.

    Raises:
        NotImplementedError: If ``use_conv`` is false.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if not use_conv:
            raise NotImplementedError("Downsample3D is only implemented for use_conv=True")
        conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=2, padding=padding)

        # Every `name` exposes the layer as `self.conv`; the default name also
        # registers the same module as `Conv2d_0`, so both key prefixes appear
        # in the state dict. The registration order is kept as it was.
        if name == "conv":
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        """Downsample ``hidden_states`` with the strided convolution.

        Raises:
            AssertionError: If dim 1 of the input differs from ``channels``.
            NotImplementedError: If the layer was built with ``padding=0``.
        """
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError("Downsample3D does not implement padding=0")

        return self.conv(hidden_states)
108
+
109
+
110
class ResnetBlock3D(nn.Module):
    """Residual block built from per-frame (inflated) 2D convolutions.

    Pipeline: ``norm1 -> act -> conv1 -> (+ time embedding) -> norm2 ->
    (scale/shift) -> act -> dropout -> conv2``, added to the (optionally
    1x1-projected) input and divided by ``output_scale_factor``.

    Args:
        in_channels: Channels of the input (dim 1).
        out_channels: Channels of the output; defaults to ``in_channels``.
        conv_shortcut: Stored as ``use_conv_shortcut``; not otherwise used by
            this block.
        dropout: Dropout probability applied before ``conv2``.
        temb_channels: Size of the time embedding; ``None`` disables the
            time-embedding projection.
        groups: Number of groups for ``norm1``.
        groups_out: Number of groups for ``norm2``; defaults to ``groups``.
        pre_norm: Accepted for signature compatibility only; the block always
            records ``pre_norm = True``.
        eps: Epsilon of both group norms.
        non_linearity: One of ``"swish"``, ``"mish"`` or ``"silu"``.
        time_embedding_norm: ``"default"`` adds the projected embedding after
            ``conv1``; ``"scale_shift"`` uses it as a scale and shift after
            ``norm2``.
        output_scale_factor: Divisor applied to the residual sum.
        use_in_shortcut: Force (or forbid) the 1x1 shortcut convolution; by
            default it is used only when the channel counts differ.

    Raises:
        ValueError: If ``non_linearity`` is not supported, or if
            ``temb_channels`` is set and ``time_embedding_norm`` is unknown.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        # Fail at construction time instead of leaving `self.nonlinearity`
        # undefined and crashing later inside `forward`.
        if non_linearity not in ("swish", "mish", "silu"):
            raise ValueError(f"unknown non_linearity : {non_linearity} ")

        # The block is always pre-norm; the `pre_norm` argument is ignored.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                # One half is used as scale, the other as shift.
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            # Plain function reference (not a lambda) so the module stays picklable.
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        else:  # "silu"
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        """Apply the residual block.

        Args:
            input_tensor: 5D input whose dim 1 has ``in_channels`` entries.
            temb: Time embedding or ``None``. It must be 3D: after projection
                its dims 1 and 2 are swapped and two trailing singleton dims
                are added so it broadcasts over the last two dims of the
                features (presumably ``(batch, frames, temb_channels)``, i.e.
                one embedding per frame; confirm against the caller).

        Returns:
            ``(shortcut(input_tensor) + residual) / output_scale_factor``.
        """
        hidden_states = self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # (d0, d1, C) -> (d0, d1, C, 1, 1) -> (d0, C, d1, 1, 1)
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, :, None, None].permute(0, 2, 1, 3, 4)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
206
+
207
+
208
class Mish(torch.nn.Module):
    """Mish activation, computed as ``x * tanh(softplus(x))``."""

    def forward(self, hidden_states):
        """Return the element-wise Mish of ``hidden_states``."""
        gate = torch.tanh(F.softplus(hidden_states))
        return hidden_states * gate
2D_Stage/tuneavideo/models/transformer_mv2d.py ADDED
@@ -0,0 +1,1010 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ try:
25
+ from diffusers.utils import maybe_allow_in_graph
26
+ except:
27
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
28
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
29
+ from diffusers.models.embeddings import PatchEmbed
30
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.utils.import_utils import is_xformers_available
33
+
34
+ from einops import rearrange
35
+ import pdb
36
+ import random
37
+
38
+
39
+ if is_xformers_available():
40
+ import xformers
41
+ import xformers.ops
42
+ else:
43
+ xformers = None
44
+
45
+
46
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`TransformerMV2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerMV2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    # Tensor produced by `TransformerMV2DModel.forward` when `return_dict=True`.
    sample: torch.FloatTensor
58
+
59
+
60
class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data whose layers are `BasicMVTransformerBlock`s.

    Exactly one input mode is selected from the configuration: continuous (`in_channels` set, `patch_size` unset),
    vectorized/discrete (`num_vector_embeds` set) or patched (`in_channels` and `patch_size` both set).

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        out_channels (`int`, *optional*):
            Stored as `self.out_channels` (defaults to `in_channels`) and used to size the patched output
            projection. The continuous output projection maps back to `in_channels` (see the TODO in `__init__`).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups of the input `GroupNorm` used for continuous input.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
            Also required for patched input.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*): Patch size; selects the patched input mode when set with `in_channels`.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            For continuous input, use linear layers instead of 1x1 convolutions for `proj_in` / `proj_out`.
        only_cross_attention, upcast_attention, norm_type, norm_elementwise_affine, num_views, joint_attention,
        joint_attention_twice, multiview_attention, cross_domain_attention:
            Forwarded unchanged to every `BasicMVTransformerBlock`; see that class for their meaning.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        joint_attention: bool=False,
        joint_attention_twice: bool=False,
        multiview_attention: bool=True,
        cross_domain_attention: bool=False
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # Width of the token embeddings processed by the transformer blocks.
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        # Legacy configs: `layer_norm` together with `num_embeds_ada_norm` is
        # treated as `ada_norm` after emitting a deprecation notice.
        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        # Reject configurations that select more than one, or none, of the input modes.
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        # Every layer gets the same configuration, including the multi-view /
        # joint-attention switches.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    num_views=num_views,
                    joint_attention=joint_attention,
                    joint_attention_twice=joint_attention_twice,
                    multiview_attention=multiview_attention,
                    cross_domain_attention=cross_domain_attention
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs (`dict`, *optional*):
                Passed unchanged to every transformer block.
            attention_mask ( `torch.Tensor`, *optional*):
                Self-attention mask. A 2D mask is converted to an additive bias exactly like
                `encoder_attention_mask` before being passed to the blocks.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            # Kept for the residual connection added after the output projection.
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                # Conv projection runs on the image layout, then flatten to tokens.
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # Linear projection runs on tokens, so flatten first.
                # Here `inner_dim` is the channel count *before* projection.
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                # NOTE(review): `proj_out` emits `in_channels` features and the
                # reshape uses `inner_dim`, which in this branch was read from
                # the input before `proj_in`, i.e. it also equals `in_channels`.
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            # Requires the first block's `norm1` to expose an `emb` module
            # (presumably `AdaLayerNormZero`; confirm against the block config).
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            # Assumes a square grid of patches: side = sqrt(number of tokens).
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return TransformerMV2DModelOutput(sample=output)
369
+
370
+
371
+ @maybe_allow_in_graph
372
+ class BasicMVTransformerBlock(nn.Module):
373
+ r"""
374
+ A basic Transformer block.
375
+
376
+ Parameters:
377
+ dim (`int`): The number of channels in the input and output.
378
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
379
+ attention_head_dim (`int`): The number of channels in each head.
380
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
381
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
382
+ only_cross_attention (`bool`, *optional*):
383
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
384
+ double_self_attention (`bool`, *optional*):
385
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
386
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
387
+ num_embeds_ada_norm (:
388
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
389
+ attention_bias (:
390
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
391
+ """
392
+
393
+ def __init__(
394
+ self,
395
+ dim: int,
396
+ num_attention_heads: int,
397
+ attention_head_dim: int,
398
+ dropout=0.0,
399
+ cross_attention_dim: Optional[int] = None,
400
+ activation_fn: str = "geglu",
401
+ num_embeds_ada_norm: Optional[int] = None,
402
+ attention_bias: bool = False,
403
+ only_cross_attention: bool = False,
404
+ double_self_attention: bool = False,
405
+ upcast_attention: bool = False,
406
+ norm_elementwise_affine: bool = True,
407
+ norm_type: str = "layer_norm",
408
+ final_dropout: bool = False,
409
+ num_views: int = 1,
410
+ joint_attention: bool = False,
411
+ joint_attention_twice: bool = False,
412
+ multiview_attention: bool = True,
413
+ cross_domain_attention: bool = False
414
+ ):
415
+ super().__init__()
416
+ self.only_cross_attention = only_cross_attention
417
+
418
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
419
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
420
+
421
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
422
+ raise ValueError(
423
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
424
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
425
+ )
426
+
427
+ # Define 3 blocks. Each block has its own normalization layer.
428
+ # 1. Self-Attn
429
+ if self.use_ada_layer_norm:
430
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
431
+ elif self.use_ada_layer_norm_zero:
432
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
435
+
436
+ self.multiview_attention = multiview_attention
437
+ self.cross_domain_attention = cross_domain_attention
438
+ # import pdb;pdb.set_trace()
439
+ self.attn1 = CustomAttention(
440
+ query_dim=dim,
441
+ heads=num_attention_heads,
442
+ dim_head=attention_head_dim,
443
+ dropout=dropout,
444
+ bias=attention_bias,
445
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
446
+ upcast_attention=upcast_attention,
447
+ processor=MVAttnProcessor()
448
+ )
449
+
450
+ # 2. Cross-Attn
451
+ if cross_attention_dim is not None or double_self_attention:
452
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
453
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
454
+ # the second cross attention block.
455
+ self.norm2 = (
456
+ AdaLayerNorm(dim, num_embeds_ada_norm)
457
+ if self.use_ada_layer_norm
458
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
459
+ )
460
+ self.attn2 = Attention(
461
+ query_dim=dim,
462
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
463
+ heads=num_attention_heads,
464
+ dim_head=attention_head_dim,
465
+ dropout=dropout,
466
+ bias=attention_bias,
467
+ upcast_attention=upcast_attention,
468
+ ) # is self-attn if encoder_hidden_states is none
469
+ else:
470
+ self.norm2 = None
471
+ self.attn2 = None
472
+
473
+ # 3. Feed-forward
474
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
475
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
476
+
477
+ # let chunk size default to None
478
+ self._chunk_size = None
479
+ self._chunk_dim = 0
480
+
481
+ self.num_views = num_views
482
+
483
+ self.joint_attention = joint_attention
484
+
485
+ if self.joint_attention:
486
+ # Joint task -Attn
487
+ self.attn_joint = CustomJointAttention(
488
+ query_dim=dim,
489
+ heads=num_attention_heads,
490
+ dim_head=attention_head_dim,
491
+ dropout=dropout,
492
+ bias=attention_bias,
493
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
494
+ upcast_attention=upcast_attention,
495
+ processor=JointAttnProcessor()
496
+ )
497
+ nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
498
+ self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
499
+
500
+
501
+ self.joint_attention_twice = joint_attention_twice
502
+
503
+ if self.joint_attention_twice:
504
+ print("joint twice")
505
+ # Joint task -Attn
506
+ self.attn_joint_twice = CustomJointAttention(
507
+ query_dim=dim,
508
+ heads=num_attention_heads,
509
+ dim_head=attention_head_dim,
510
+ dropout=dropout,
511
+ bias=attention_bias,
512
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
513
+ upcast_attention=upcast_attention,
514
+ processor=JointAttnProcessor()
515
+ )
516
+ nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
517
+ self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
518
+
519
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
    """Configure chunked execution of the feed-forward layer.

    Args:
        chunk_size: Size of each slice handed to the feed-forward layer, or
            None to run the feed-forward on the whole tensor at once.
        dim: Axis along which `forward` splits the normalized hidden states.
    """
    self._chunk_size, self._chunk_dim = chunk_size, dim
523
+
524
def forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Apply the block: multi-view self-attention, optional joint attention,
    optional cross-attention, feed-forward, and an optional trailing joint
    attention. Every sub-layer is pre-normalized and added residually.

    Args:
        hidden_states: Input features; also used as the residual stream.
        attention_mask: Must be None (self-attention masks are unsupported).
        encoder_hidden_states: Conditioning passed to `attn2`, and to `attn1`
            only when `self.only_cross_attention` is set.
        encoder_attention_mask: Mask forwarded to `attn2`.
        timestep: Conditioning for the AdaLayerNorm variants.
        cross_attention_kwargs: Extra keyword arguments forwarded to both
            `attn1` and `attn2`; None is treated as an empty dict.
        class_labels: Conditioning used only by the AdaLayerNormZero `norm1`.

    Returns:
        The updated hidden states.

    Raises:
        ValueError: If feed-forward chunking is enabled and the chunked axis
            is not divisible by the configured chunk size.
    """
    assert attention_mask is None  # not supported yet

    def _pre_norm(norm_layer, states):
        # AdaLayerNorm layers take the timestep; plain norms only the input.
        if self.use_ada_layer_norm:
            return norm_layer(states, timestep)
        return norm_layer(states)

    # Normalization is always applied before the real computation below.
    # 1. Self-Attention
    if self.use_ada_layer_norm:
        normed = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        normed, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        normed = self.norm1(hidden_states)

    extra_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs
    self_attn_context = encoder_hidden_states if self.only_cross_attention else None
    attn_output = self.attn1(
        normed,
        encoder_hidden_states=self_attn_context,
        attention_mask=attention_mask,
        num_views=self.num_views,
        multiview_attention=self.multiview_attention,
        cross_domain_attention=self.cross_domain_attention,
        **extra_kwargs,
    )
    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # Optional joint attention placed right after the self-attention.
    if self.joint_attention_twice:
        joint_input = _pre_norm(self.norm_joint_twice, hidden_states)
        hidden_states = self.attn_joint_twice(joint_input) + hidden_states

    # 2. Cross-Attention
    if self.attn2 is not None:
        cross_output = self.attn2(
            _pre_norm(self.norm2, hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **extra_kwargs,
        )
        hidden_states = cross_output + hidden_states

    # 3. Feed-forward
    normed = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        normed = normed * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is None:
        ff_output = self.ff(normed)
    else:
        # "feed_forward_chunk_size" can be used to save memory
        axis_len = normed.shape[self._chunk_dim]
        if axis_len % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {axis_len} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        pieces = normed.chunk(axis_len // self._chunk_size, dim=self._chunk_dim)
        ff_output = torch.cat([self.ff(piece) for piece in pieces], dim=self._chunk_dim)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    hidden_states = ff_output + hidden_states

    # Optional joint attention at the very end of the block.
    if self.joint_attention:
        joint_input = _pre_norm(self.norm_joint, hidden_states)
        hidden_states = self.attn_joint(joint_input) + hidden_states

    return hidden_states
615
+
616
+
617
class CustomAttention(Attention):
    """Attention layer that uses the multi-view attention processors."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Select the multi-view processor matching the requested backend.

        Args:
            use_memory_efficient_attention_xformers: When true, install
                `XFormersMVAttnProcessor`; when false, fall back to the plain
                `MVAttnProcessor`. The flag used to be ignored, so asking to
                disable xformers still installed the xformers processor.
            *args, **kwargs: Accepted for signature compatibility and unused
                (e.g. an attention-op argument supplied by the caller).
        """
        if use_memory_efficient_attention_xformers:
            processor = XFormersMVAttnProcessor()
        else:
            processor = MVAttnProcessor()
        self.set_processor(processor)
624
+
625
+
626
class CustomJointAttention(Attention):
    """Attention layer that uses the joint (two-task) attention processors."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Select the joint-attention processor matching the requested backend.

        Args:
            use_memory_efficient_attention_xformers: When true, install
                `XFormersJointAttnProcessor`; when false, fall back to the
                plain `JointAttnProcessor`. The flag used to be ignored, so
                asking to disable xformers still installed the xformers
                processor.
            *args, **kwargs: Accepted for signature compatibility and unused.
        """
        if use_memory_efficient_attention_xformers:
            processor = XFormersJointAttnProcessor()
        else:
            processor = JointAttnProcessor()
        self.set_processor(processor)
633
+
634
class MVAttnProcessor:
    r"""
    Multi-view attention processor implemented with plain PyTorch ops (no xformers).

    With `multiview_attention` enabled, the keys/values of all views belonging to the same sample are
    concatenated along the token axis so that every view attends to every other view.
    """

    @staticmethod
    def _share_across_views(tensor, num_views):
        """Merge the tokens of each group of `num_views` consecutive batch rows and give every row of
        the group the merged sequence: `(b*t, d, c) -> (b*t, t*d, c)`."""
        _, tokens, channels = tensor.shape
        merged = tensor.reshape(-1, num_views * tokens, channels)
        return merged.repeat_interleave(num_views, dim=0)

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        cross_domain_attention=False,
    ):
        r"""
        Args:
            attn: The attention layer providing the projections and helpers.
            hidden_states: Query features, 3-D `(batch, tokens, channels)` or 4-D `(batch, channels, h, w)`.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Optional mask passed through `attn.prepare_attention_mask`.
            temb: Embedding handed to `attn.spatial_norm` when that layer exists.
            num_views: Number of consecutive batch rows that form one multi-view group.
            multiview_attention: Share keys/values across the views of a group. For more than 6 views the
                sharing is skipped (sparse attention is not implemented) and each view attends to itself.
            cross_domain_attention: Additionally append the keys/values of the other half of the batch,
                mirroring `XFormersMVAttnProcessor`. The caller always passes this keyword, so it must be
                accepted here as well.

        Returns:
            The attended features with the same layout as the input `hidden_states`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)
        key, value = key_raw, value_raw

        # multi-view self-attention
        if multiview_attention:
            if num_views <= 6:
                key = self._share_across_views(key_raw, num_views)
                value = self._share_across_views(value_raw, num_views)
            # else: sparse attention is not implemented; keep per-view keys/values.

            if cross_domain_attention:
                # Swap the two halves of the batch and append them along the token axis,
                # as XFormersMVAttnProcessor does.
                key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)
                value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
                key = torch.cat([key, torch.cat([key_1, key_0], dim=0)], dim=1)
                value = torch.cat([value, torch.cat([value_1, value_0], dim=0)], dim=1)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
743
+
744
+
745
class XFormersMVAttnProcessor:
    r"""
    Multi-view attention processor that computes attention with
    `xformers.ops.memory_efficient_attention`.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        cross_domain_attention=False,
    ):
        r"""
        Args:
            attn: The attention layer providing the projections and helpers.
            hidden_states: Query features, 3-D `(batch, tokens, channels)` or 4-D `(batch, channels, h, w)`.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Optional mask, expanded over the query tokens before being used as a bias.
            temb: Embedding handed to `attn.spatial_norm` when that layer exists.
            num_views: Number of consecutive batch rows that form one multi-view group. Must be an
                integer: it is used as a repeat count and as an einops axis length (the previous float
                default `1.` could not be used for either).
            multiview_attention: Share keys/values across the views of a group.
            cross_domain_attention: With multi-view attention enabled, additionally append the
                keys/values of the other half of the batch along the token axis.

        Returns:
            The attended features with the same layout as the input `hidden_states`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        # multi-view self-attention
        if multiview_attention:
            # (b t) d c -> b (t d) c, then give every view of a group the merged sequence
            key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
            value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)

            if cross_domain_attention:
                # memory efficient, cross domain attention
                key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)  # keys shape (b t) d c
                value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
                key_cross = torch.concat([key_1, key_0], dim=0)
                value_cross = torch.concat([value_1, value_0], dim=0)  # shape (b t) d c
                key = torch.cat([key, key_cross], dim=1)
                value = torch.cat([value, value_cross], dim=1)  # shape (b t) (t+1 d) c
        else:
            key = key_raw
            value = value_raw

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
843
+
844
+
845
+
846
class XFormersJointAttnProcessor:
    r"""
    Joint attention over two tasks stacked along the batch axis, computed with
    `xformers.ops.memory_efficient_attention`.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # xformers does not broadcast the singleton query dimension of the mask, so expand
            # [batch*heads, 1, key_tokens] to [batch*heads, query_tokens, key_tokens] explicitly.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        # Split the batch into its two task halves, lay them side by side along the token axis
        # ((b t) d c -> (b t) 2d c) and hand the joint sequence to both halves of the batch.
        key_first, key_second = key.chunk(2, dim=0)
        value_first, value_second = value.chunk(2, dim=0)
        joint_key = torch.cat((key_first, key_second), dim=1)
        joint_value = torch.cat((value_first, value_second), dim=1)
        key = torch.cat((joint_key, joint_key), dim=0)
        value = torch.cat((joint_value, joint_value), dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # output projection followed by dropout
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
932
+
933
+
934
class JointAttnProcessor:
    r"""
    Joint attention over two tasks stacked along the batch axis, computed with plain PyTorch ops.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        # Split the batch into its two task halves, lay them side by side along the token axis
        # ((b t) d c -> (b t) 2d c) and hand the joint sequence to both halves of the batch.
        key_first, key_second = key.chunk(2, dim=0)
        value_first, value_second = value.chunk(2, dim=0)
        joint_key = torch.cat((key_first, key_second), dim=1)
        joint_value = torch.cat((value_first, value_second), dim=1)
        key = torch.cat((joint_key, joint_key), dim=0)
        value = torch.cat((joint_value, joint_value), dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # output projection followed by dropout
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
2D_Stage/tuneavideo/models/unet.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers import ModelMixin
15
+ from diffusers.utils import BaseOutput, logging
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from .unet_blocks import (
18
+ CrossAttnDownBlock3D,
19
+ CrossAttnUpBlock3D,
20
+ DownBlock3D,
21
+ UNetMidBlock3DCrossAttn,
22
+ UpBlock3D,
23
+ get_down_block,
24
+ get_up_block,
25
+ )
26
+ from .resnet import InflatedConv3d
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Output container named in the return annotation of `UNet3DConditionModel.forward`."""

    # Output tensor of the UNet (presumably the denoised sample; shape not shown here - TODO confirm).
    sample: torch.FloatTensor
35
+
36
+
37
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
38
+ _supports_gradient_checkpointing = True
39
+
40
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D",
    ),
    mid_block_type: str = "UNetMidBlock3DCrossAttn",
    up_block_types: Tuple[str] = (
        "UpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D"
    ),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: int = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: int = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: int = 1280,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    use_attn_temp: bool = False,
    camera_input_dim: int = 12,
    camera_hidden_dim: int = 320,
    camera_output_dim: int = 1280,
):
    """Build the UNet: input conv, time/class/camera embeddings, the down,
    mid and up stages, and the output head.

    `center_input_sample`, `camera_hidden_dim` and `camera_output_dim` are
    accepted but not read inside this constructor. A scalar
    `only_cross_attention` or `attention_head_dim` is broadcast to one entry
    per down block.

    Raises:
        ValueError: If `mid_block_type` is not "UNetMidBlock3DCrossAttn".
    """
    super().__init__()

    self.sample_size = sample_size
    base_channels = block_out_channels[0]
    time_embed_dim = base_channels * 4

    # input
    self.conv_in = InflatedConv3d(in_channels, base_channels, kernel_size=3, padding=(1, 1))

    # time
    self.time_proj = Timesteps(base_channels, flip_sin_to_cos, freq_shift)
    timestep_input_dim = base_channels
    self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    else:
        self.class_embedding = None

    # camera matrix embedding, projected to the time-embedding width
    self.camera_embedding = nn.Sequential(
        nn.Linear(camera_input_dim, time_embed_dim),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )

    # These three assignments stay in this order so that the module
    # registration order is the same as before.
    self.down_blocks = nn.ModuleList([])
    self.mid_block = None
    self.up_blocks = nn.ModuleList([])

    if isinstance(only_cross_attention, bool):
        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    final_index = len(block_out_channels) - 1

    # down
    output_channel = base_channels
    for idx, block_type in enumerate(down_block_types):
        input_channel, output_channel = output_channel, block_out_channels[idx]
        self.down_blocks.append(
            get_down_block(
                block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                # every stage downsamples except the last one
                add_downsample=idx != final_index,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[idx],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[idx],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                use_attn_temp=use_attn_temp,
            )
        )

    # mid
    if mid_block_type != "UNetMidBlock3DCrossAttn":
        raise ValueError(f"unknown mid_block_type : {mid_block_type}")
    self.mid_block = UNetMidBlock3DCrossAttn(
        in_channels=block_out_channels[-1],
        temb_channels=time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        resnet_time_scale_shift=resnet_time_scale_shift,
        cross_attention_dim=cross_attention_dim,
        attn_num_head_channels=attention_head_dim[-1],
        resnet_groups=norm_num_groups,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )

    # count how many layers upsample the videos
    self.num_upsamplers = 0

    # up
    channels_rev = list(reversed(block_out_channels))
    head_dims_rev = list(reversed(attention_head_dim))
    only_cross_attention = list(reversed(only_cross_attention))
    output_channel = channels_rev[0]
    for idx, block_type in enumerate(up_block_types):
        prev_output_channel, output_channel = output_channel, channels_rev[idx]
        input_channel = channels_rev[min(idx + 1, final_index)]

        # add upsample block for all BUT final layer
        add_upsample = idx != final_index
        if add_upsample:
            self.num_upsamplers += 1

        self.up_blocks.append(
            get_up_block(
                block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=head_dims_rev[idx],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[idx],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                use_attn_temp=use_attn_temp,
            )
        )

    # out
    self.conv_norm_out = nn.GroupNorm(num_channels=base_channels, num_groups=norm_num_groups, eps=norm_eps)
    self.conv_act = nn.SiLU()
    self.conv_out = InflatedConv3d(base_channels, out_channels, kernel_size=3, padding=1)
230
+
231
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module will split the input tensor in slices, to compute attention
    in several steps. This is useful to save some memory in exchange for a small speed decrease.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
            `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`.

    Raises:
        ValueError: If a list of the wrong length is given, or a slice size exceeds the layer's
            `sliceable_head_dim`.
    """

    def _collect_sliceable(module: torch.nn.Module, found: List[torch.nn.Module]):
        # Pre-order walk: the module itself first, then its children.
        if hasattr(module, "set_attention_slice"):
            found.append(module)
        for child in module.children():
            _collect_sliceable(child, found)

    # Gather every descendant that exposes `set_attention_slice`.
    sliceable_modules = []
    for top_level in self.children():
        _collect_sliceable(top_level, sliceable_modules)
    sliceable_head_dims = [module.sliceable_head_dim for module in sliceable_modules]
    layer_count = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = layer_count * [1]

    if not isinstance(slice_size, list):
        slice_size = layer_count * [slice_size]

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Hand each sliceable module its slice size, in traversal order.
    for module, size in zip(sliceable_modules, slice_size):
        module.set_attention_slice(size)
295
+
296
def _set_gradient_checkpointing(self, module, value=False):
    """Set `gradient_checkpointing` on `module` if it is one of the 3D down/up blocks;
    any other module is left untouched."""
    checkpointable = (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)
    if not isinstance(module, checkpointable):
        return
    module.gradient_checkpointing = value
299
+
300
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    camera_matrixs: Optional[torch.Tensor] = None,
    class_labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
    r"""
    Run the conditional UNet on a noisy sample.

    Args:
        sample (`torch.FloatTensor`): noisy input tensor. Its last two dimensions are checked against
            the overall upsampling factor; its first dimension is used as the batch size for timesteps.
        timestep (`torch.Tensor` or `float` or `int`): timestep(s). Python scalars and 0-dim tensors are
            turned into 1-element tensors and expanded to `sample.shape[0]`.
        encoder_hidden_states (`torch.Tensor`): conditioning forwarded to the cross-attention blocks.
        camera_matrixs (`torch.Tensor`, *optional*): when given, embedded with `self.camera_embedding` and
            added to the time embedding (the time embedding is repeated along dim 1 to match).
        class_labels (`torch.Tensor`, *optional*): when given and the model has a class embedding, embedded
            and added to the time embedding.
        attention_mask (`torch.Tensor`, *optional*): converted to an additive mask
            (`(1 - mask) * -10000.0`) and forwarded to the cross-attention blocks and the mid block.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `UNet3DConditionOutput` instead of a plain tuple.

    Returns:
        `UNet3DConditionOutput` if `return_dict` is True, otherwise a 1-tuple whose only element is the
        output sample tensor.
    """
    # By default the spatial size has to be a multiple of the overall upsampling factor,
    # i.e. 2 ** (number of upsampling layers). When it is not, the target size of every
    # upsampling step is forwarded explicitly so interpolation can be forced to fit.
    overall_up_factor = 2**self.num_upsamplers
    upsample_size = None
    needs_upsample_size = any(dim % overall_up_factor != 0 for dim in sample.shape[-2:])
    if needs_upsample_size:
        logger.info("Forward upsample size to force interpolation output size.")

    # Turn the mask into an additive bias and add a broadcast dimension at position 1.
    if attention_mask is not None:
        additive_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = additive_mask.unsqueeze(1)

    # Optionally centre the input.
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # --- time embedding -------------------------------------------------
    timesteps = timestep
    if torch.is_tensor(timesteps):
        if len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
    else:
        # 64-bit dtypes are avoided on the "mps" device.
        on_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            scalar_dtype = torch.float32 if on_mps else torch.float64
        else:
            scalar_dtype = torch.int32 if on_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=scalar_dtype, device=sample.device)

    # Broadcast to the batch dimension in a way that's compatible with ONNX/Core ML.
    timesteps = timesteps.expand(sample.shape[0])

    # `time_proj` has no weights and returns f32 tensors, while `time_embedding` may run
    # in fp16, hence the explicit cast to the model dtype.
    t_emb = self.time_proj(timesteps).to(dtype=self.dtype)
    emb = self.time_embedding(t_emb).unsqueeze(1)

    if camera_matrixs is not None:
        cam_emb = self.camera_embedding(camera_matrixs)
        emb = emb.repeat(1, cam_emb.shape[1], 1) + cam_emb

    # Class conditioning is optional: it is skipped when no labels are supplied.
    if self.class_embedding is not None and class_labels is not None:
        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)
        emb = emb + self.class_embedding(class_labels)

    # --- pre-process ----------------------------------------------------
    sample = self.conv_in(sample)

    # --- down path: collect residuals for the skip connections ----------
    skip_samples = (sample,)
    for down_block in self.down_blocks:
        down_kwargs = {"hidden_states": sample, "temb": emb}
        if getattr(down_block, "has_cross_attention", False):
            down_kwargs["encoder_hidden_states"] = encoder_hidden_states
            down_kwargs["attention_mask"] = attention_mask
        sample, res_samples = down_block(**down_kwargs)
        skip_samples += res_samples

    # --- mid ------------------------------------------------------------
    sample = self.mid_block(
        sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
    )

    # --- up path: consume residuals from the end of the collected tuple -
    last_block_index = len(self.up_blocks) - 1
    for block_index, up_block in enumerate(self.up_blocks):
        num_skips = len(up_block.resnets)
        res_samples = skip_samples[-num_skips:]
        skip_samples = skip_samples[:-num_skips]

        # Every block except the last one upsamples towards the size of the next residual.
        if block_index != last_block_index and needs_upsample_size:
            upsample_size = skip_samples[-1].shape[2:]

        up_kwargs = {
            "hidden_states": sample,
            "temb": emb,
            "res_hidden_states_tuple": res_samples,
            "upsample_size": upsample_size,
        }
        if getattr(up_block, "has_cross_attention", False):
            up_kwargs["encoder_hidden_states"] = encoder_hidden_states
            up_kwargs["attention_mask"] = attention_mask
        sample = up_block(**up_kwargs)

    # --- post-process ---------------------------------------------------
    sample = self.conv_out(self.conv_act(self.conv_norm_out(sample)))

    if return_dict:
        return UNet3DConditionOutput(sample=sample)
    return (sample,)
445
+
446
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
    """Build this model from a 2D UNet checkpoint directory.

    The directory's `config.json` is loaded, its class name and down/up block types are
    overridden with the 3D variants, and a model is created from that config. The
    checkpoint weights are then loaded on top of it; parameters whose names contain
    `_temp.`, `camera_embedding` or `class_embedding` keep the freshly initialised
    values of the new model instead of being read from the checkpoint.

    Args:
        pretrained_model_path: directory containing `config.json` and a weights file.
        subfolder: optional sub-directory of `pretrained_model_path` to look in.

    Returns:
        The model with the merged state dict loaded.

    Raises:
        RuntimeError: if `config.json` is missing, or if neither the diffusers
            `WEIGHTS_NAME` file nor the `SAFETENSORS_WEIGHTS_NAME` file exists.
    """
    if subfolder is not None:
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

    config_file = os.path.join(pretrained_model_path, "config.json")
    if not os.path.isfile(config_file):
        raise RuntimeError(f"{config_file} does not exist")
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Re-target the 2D config at this class and its 3D block implementations.
    config["_class_name"] = cls.__name__
    config["down_block_types"] = [
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D",
    ]
    config["up_block_types"] = [
        "UpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
    ]

    from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME

    # Import the submodule explicitly: a bare `import safetensors` does not guarantee
    # that `safetensors.torch` has been loaded, which would make the attribute access
    # below fail unless some other module imported it first.
    import safetensors.torch

    model = cls.from_config(config)

    # Prefer the torch checkpoint; fall back to the safetensors file.
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    if os.path.isfile(model_file):
        # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(model_file, map_location="cpu")
    else:
        model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        state_dict = safetensors.torch.load_file(model_file, device="cpu")

    # Parameters that only exist in the new model keep their initial values.
    for k, v in model.state_dict().items():
        if '_temp.' in k or 'camera_embedding' in k or 'class_embedding' in k:
            state_dict.update({k: v})

    # Drop every entry whose key contains 'camera_embedding_' before loading.
    for k in list(state_dict.keys()):
        if 'camera_embedding_' in k:
            state_dict.pop(k)
    model.load_state_dict(state_dict)

    return model
2D_Stage/tuneavideo/models/unet_blocks.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ # from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+
9
+
10
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    use_attn_temp=False,
):
    """Instantiate the down block named by `down_block_type`.

    A leading "UNetRes" in the type name is stripped before matching. Supported names
    are "DownBlock3D" and "CrossAttnDownBlock3D"; the attention-related arguments are
    only forwarded to the latter.

    Raises:
        ValueError: if the name is unknown, or if `cross_attention_dim` is None for
            "CrossAttnDownBlock3D".
    """
    legacy_prefix = "UNetRes"
    if down_block_type.startswith(legacy_prefix):
        down_block_type = down_block_type[len(legacy_prefix):]

    # Arguments shared by both block flavours.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        downsample_padding=downsample_padding,
        resnet_time_scale_shift=resnet_time_scale_shift,
    )

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    if down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            use_attn_temp=use_attn_temp,
            **shared_kwargs,
        )

    raise ValueError(f"{down_block_type} does not exist.")
67
+
68
+
69
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    use_attn_temp=False,
):
    """Instantiate the up block named by `up_block_type`.

    A leading "UNetRes" in the type name is stripped before matching. Supported names
    are "UpBlock3D" and "CrossAttnUpBlock3D"; the attention-related arguments are only
    forwarded to the latter.

    Raises:
        ValueError: if the name is unknown, or if `cross_attention_dim` is None for
            "CrossAttnUpBlock3D".
    """
    legacy_prefix = "UNetRes"
    if up_block_type.startswith(legacy_prefix):
        up_block_type = up_block_type[len(legacy_prefix):]

    # Arguments shared by both block flavours.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        prev_output_channel=prev_output_channel,
        temb_channels=temb_channels,
        add_upsample=add_upsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
    )

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    if up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            use_attn_temp=use_attn_temp,
            **shared_kwargs,
        )

    raise ValueError(f"{up_block_type} does not exist.")
126
+
127
+
128
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle UNet block: one resnet followed by `num_layers` (attention, resnet) pairs.

    All layers keep the channel count at `in_channels`.

    NOTE(review): `Transformer3DModel` is used below, but its import from `.attention`
    is commented out at the top of this file -- confirm the name is in scope before
    instantiating this block.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def make_resnet():
            # Every resnet of the mid block shares the same configuration.
            return ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # There is always at least one resnet.
        resnet_layers = [make_resnet()]
        attention_layers = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attention_layers.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            resnet_layers.append(make_resnet())

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Apply the first resnet, then each (attention, resnet) pair in order.

        `attention_mask` is accepted for interface compatibility but is not used here.
        """
        remaining_resnets = iter(self.resnets)
        hidden_states = next(remaining_resnets)(hidden_states, temb)
        for attn, resnet in zip(self.attentions, remaining_resnets):
            attn_out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = resnet(attn_out.sample, temb)

        return hidden_states
210
+
211
+
212
class CrossAttnDownBlock3D(nn.Module):
    """Down block made of `num_layers` (resnet, attention) pairs plus an optional downsampler.

    NOTE(review): `Transformer3DModel` is used below, but its import from `.attention`
    is commented out at the top of this file -- confirm the name is in scope before
    instantiating this block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        use_attn_temp=False,
    ):
        super().__init__()
        resnet_layers = []
        attention_layers = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input channel count.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attention_layers.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_attn_temp=use_attn_temp,
                )
            )
        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample3D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Run the block.

        Returns `(hidden_states, output_states)`, where `output_states` is a tuple holding
        the output of every (resnet, attention) pair and, if present, the downsampled
        result. `attention_mask` is accepted but not used here.
        """

        def run_checkpointed(module, *inputs, return_dict=None):
            # Wrap the module call so it can be re-run during the backward pass.
            def call_module(*args):
                if return_dict is None:
                    return module(*args)
                return module(*args, return_dict=return_dict)

            return torch.utils.checkpoint.checkpoint(call_module, *inputs)

        collected_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:
                hidden_states = run_checkpointed(resnet, hidden_states, temb)
                hidden_states = run_checkpointed(
                    attn, hidden_states, encoder_hidden_states, return_dict=False
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

            collected_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected_states.append(hidden_states)

        return hidden_states, tuple(collected_states)
325
+
326
+
327
class DownBlock3D(nn.Module):
    """Down block made of `num_layers` resnets plus an optional downsampler (no attention)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Only the first resnet sees the block's input channel count.
        self.resnets = nn.ModuleList(
            [
                ResnetBlock3D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                for layer_idx in range(num_layers)
            ]
        )

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample3D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Run the block.

        Returns `(hidden_states, output_states)`, where `output_states` is a tuple holding
        the output of every resnet and, if present, the downsampled result.
        """

        def run_checkpointed(module, *inputs):
            # Wrap the module call so it can be re-run during the backward pass.
            def call_module(*args):
                return module(*args)

            return torch.utils.checkpoint.checkpoint(call_module, *inputs)

        collected_states = []

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                hidden_states = run_checkpointed(resnet, hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

            collected_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected_states.append(hidden_states)

        return hidden_states, tuple(collected_states)
404
+
405
+
406
class CrossAttnUpBlock3D(nn.Module):
    """Up block made of `num_layers` (resnet, attention) pairs plus an optional upsampler.

    Each resnet consumes the current hidden states concatenated (along dim 1) with one
    residual taken from the end of `res_hidden_states_tuple`.

    NOTE(review): `Transformer3DModel` is used below, but its import from `.attention`
    is commented out at the top of this file -- confirm the name is in scope before
    instantiating this block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        use_attn_temp=False,
    ):
        super().__init__()
        resnet_layers = []
        attention_layers = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        final_layer_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # The last resnet receives a skip with `in_channels` channels, the others
            # with `out_channels`; the first resnet continues from the previous block.
            skip_channels = in_channels if layer_idx == final_layer_idx else out_channels
            input_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=input_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attention_layers.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_attn_temp=use_attn_temp,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        """Run the block and return the resulting hidden states.

        `attention_mask` is accepted for interface compatibility but is not used here.
        """

        def run_checkpointed(module, *inputs, return_dict=None):
            # Wrap the module call so it can be re-run during the backward pass.
            def call_module(*args):
                if return_dict is None:
                    return module(*args)
                return module(*args, return_dict=return_dict)

            return torch.utils.checkpoint.checkpoint(call_module, *inputs)

        for resnet, attn in zip(self.resnets, self.attentions):
            # Take the most recent residual off the end of the tuple.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if self.training and self.gradient_checkpointing:
                hidden_states = run_checkpointed(resnet, hidden_states, temb)
                hidden_states = run_checkpointed(
                    attn, hidden_states, encoder_hidden_states, return_dict=False
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
523
+
524
+
525
class UpBlock3D(nn.Module):
    """Up block made of `num_layers` resnets plus an optional upsampler (no attention).

    Each resnet consumes the current hidden states concatenated (along dim 1) with one
    residual taken from the end of `res_hidden_states_tuple`.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnet_layers = []

        final_layer_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # The last resnet receives a skip with `in_channels` channels, the others
            # with `out_channels`; the first resnet continues from the previous block.
            skip_channels = in_channels if layer_idx == final_layer_idx else out_channels
            input_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=input_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Run the block and return the resulting hidden states."""

        def run_checkpointed(module, *inputs):
            # Wrap the module call so it can be re-run during the backward pass.
            def call_module(*args):
                return module(*args)

            return torch.utils.checkpoint.checkpoint(call_module, *inputs)

        for resnet in self.resnets:
            # Take the most recent residual off the end of the tuple.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if self.training and self.gradient_checkpointing:
                hidden_states = run_checkpointed(resnet, hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
2D_Stage/tuneavideo/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ # from diffusers.models.attention import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+ from tuneavideo.models.transformer_mv2d import TransformerMV2DModel
27
+
28
+ from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
29
+ from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    downsample_type=None,
    num_views=1,
    joint_attention: bool = False,
    joint_attention_twice: bool = False,
    multiview_attention: bool = True,
    cross_domain_attention: bool = False,
):
    """Build the down-sampling UNet block named by ``down_block_type``.

    This is a factory: it selects a block class by name and forwards the
    subset of arguments that class accepts. The stock ``diffusers`` 2D down
    blocks are supported, plus the multi-view variant
    ``CrossAttnDownBlockMV2D`` defined in this module.

    Args:
        down_block_type: Class name of the block to build. A leading
            ``"UNetRes"`` prefix is stripped before matching.
        num_layers, in_channels, out_channels, temb_channels, add_downsample,
        resnet_eps, resnet_act_fn, resnet_groups, downsample_padding,
        resnet_time_scale_shift, resnet_skip_time_act,
        resnet_out_scale_factor, transformer_layers_per_block,
        num_attention_heads, cross_attention_dim, dual_cross_attention,
        use_linear_projection, only_cross_attention, upcast_attention,
        cross_attention_norm: Forwarded unchanged to the selected block
            class (each class only receives the arguments it takes).
        attention_head_dim: Forwarded to the attention-bearing blocks. When
            ``None`` it falls back to ``num_attention_heads`` and a warning
            is logged.
        downsample_type: Only used by ``AttnDownBlock2D``; see the branch
            below for how it is combined with ``add_downsample``.
        num_views, joint_attention, joint_attention_twice,
        multiview_attention, cross_domain_attention: Only forwarded to
            ``CrossAttnDownBlockMV2D``.

    Returns:
        A newly constructed block instance.

    Raises:
        ValueError: If ``down_block_type`` is not a known block name, or if a
            ``CrossAttnDownBlock2D`` / ``CrossAttnDownBlockMV2D`` /
            ``SimpleCrossAttnDownBlock2D`` is requested without
            ``cross_attention_dim``.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is a deprecated alias; `warning` is the supported name.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    # Accept legacy names of the form "UNetRes<BlockName>".
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type

    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "ResnetDownsampleBlock2D":
        return ResnetDownsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    elif down_block_type == "AttnDownBlock2D":
        # This block takes a downsample *type* rather than a flag: no
        # downsampling maps to None, otherwise default to "conv".
        if add_downsample is False:
            downsample_type = None
        else:
            downsample_type = downsample_type or "conv"  # default to 'conv'
        return AttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            downsample_type=downsample_type,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # custom MV2D attention block
    elif down_block_type == "CrossAttnDownBlockMV2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
        return CrossAttnDownBlockMV2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            num_views=num_views,
            joint_attention=joint_attention,
            joint_attention_twice=joint_attention_twice,
            multiview_attention=multiview_attention,
            cross_domain_attention=cross_domain_attention,
        )
    elif down_block_type == "SimpleCrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        return SimpleCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "KDownBlock2D":
        return KDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif down_block_type == "KCrossAttnDownBlock2D":
        return KCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            # Self-attention is enabled exactly when the block does not downsample.
            add_self_attention=not add_downsample,
        )
    raise ValueError(f"{down_block_type} does not exist.")
261
+
262
+
263
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    upsample_type=None,
    num_views=1,
    joint_attention: bool = False,
    joint_attention_twice: bool = False,
    multiview_attention: bool = True,
    cross_domain_attention: bool = False,
):
    """Build the up-sampling UNet block named by ``up_block_type``.

    This is a factory: it selects a block class by name and forwards the
    subset of arguments that class accepts. The stock ``diffusers`` 2D up
    blocks are supported, plus the multi-view variant
    ``CrossAttnUpBlockMV2D`` defined in this module.

    Args:
        up_block_type: Class name of the block to build. A leading
            ``"UNetRes"`` prefix is stripped before matching.
        num_layers, in_channels, out_channels, prev_output_channel,
        temb_channels, add_upsample, resnet_eps, resnet_act_fn,
        resnet_groups, resnet_time_scale_shift, resnet_skip_time_act,
        resnet_out_scale_factor, transformer_layers_per_block,
        num_attention_heads, cross_attention_dim, dual_cross_attention,
        use_linear_projection, only_cross_attention, upcast_attention,
        cross_attention_norm: Forwarded unchanged to the selected block
            class (each class only receives the arguments it takes).
        attention_head_dim: Forwarded to the attention-bearing blocks. When
            ``None`` it falls back to ``num_attention_heads`` and a warning
            is logged.
        upsample_type: Only used by ``AttnUpBlock2D``; see the branch below
            for how it is combined with ``add_upsample``.
        num_views, joint_attention, joint_attention_twice,
        multiview_attention, cross_domain_attention: Only forwarded to
            ``CrossAttnUpBlockMV2D``.

    Returns:
        A newly constructed block instance.

    Raises:
        ValueError: If ``up_block_type`` is not a known block name, or if a
            ``CrossAttnUpBlock2D`` / ``CrossAttnUpBlockMV2D`` /
            ``SimpleCrossAttnUpBlock2D`` is requested without
            ``cross_attention_dim``.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is a deprecated alias; `warning` is the supported name.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    # Accept legacy names of the form "UNetRes<BlockName>".
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type

    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "ResnetUpsampleBlock2D":
        return ResnetUpsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # custom MV2D attention block
    elif up_block_type == "CrossAttnUpBlockMV2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
        return CrossAttnUpBlockMV2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            num_views=num_views,
            joint_attention=joint_attention,
            joint_attention_twice=joint_attention_twice,
            multiview_attention=multiview_attention,
            cross_domain_attention=cross_domain_attention,
        )
    elif up_block_type == "SimpleCrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
        return SimpleCrossAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif up_block_type == "AttnUpBlock2D":
        # This block takes an upsample *type* rather than a flag: no
        # upsampling maps to None, otherwise default to "conv".
        if add_upsample is False:
            upsample_type = None
        else:
            upsample_type = upsample_type or "conv"  # default to 'conv'

        return AttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            upsample_type=upsample_type,
        )
    elif up_block_type == "SkipUpBlock2D":
        return SkipUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "AttnSkipUpBlock2D":
        return AttnSkipUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "UpDecoderBlock2D":
        return UpDecoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    elif up_block_type == "AttnUpDecoderBlock2D":
        return AttnUpDecoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    elif up_block_type == "KUpBlock2D":
        return KUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif up_block_type == "KCrossAttnUpBlock2D":
        return KCrossAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
        )

    raise ValueError(f"{up_block_type} does not exist.")
493
+
494
+
495
class UNetMidBlockMV2DCrossAttn(nn.Module):
    """UNet middle block: a resnet followed by (attention, resnet) pairs.

    The block always starts with one ``ResnetBlock2D`` and then appends
    ``num_layers`` pairs of ``TransformerMV2DModel`` + ``ResnetBlock2D``.
    Every sub-module keeps the channel count at ``in_channels``.

    The multi-view arguments (``num_views``, ``joint_attention``,
    ``joint_attention_twice``, ``multiview_attention``,
    ``cross_domain_attention``) are forwarded unchanged to
    ``TransformerMV2DModel``; their meaning is defined there.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool = False,
    ):
        """Create the sub-modules.

        Raises:
            NotImplementedError: If ``dual_cross_attention`` is truthy and at
                least one attention layer would be built.
        """
        super().__init__()

        # Fail fast, before any sub-module is allocated. The condition matches
        # the original per-layer check, which only fired when num_layers > 0.
        if dual_cross_attention and num_layers > 0:
            raise NotImplementedError(
                "dual_cross_attention is not supported by UNetMidBlockMV2DCrossAttn"
            )

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        def make_resnet():
            # All resnets in this block share one configuration.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # there is always at least one resnet
        resnets = [make_resnet()]
        attentions = []

        # Construction order (attention, then resnet) is kept as-is so that
        # parameter initialisation consumes the RNG in the same sequence.
        for _ in range(num_layers):
            attentions.append(
                TransformerMV2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    num_views=num_views,
                    joint_attention=joint_attention,
                    joint_attention_twice=joint_attention_twice,
                    multiview_attention=multiview_attention,
                    cross_domain_attention=cross_domain_attention,
                )
            )
            resnets.append(make_resnet())

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the first resnet, then each (attention, resnet) pair in order.

        ``temb`` is passed to every resnet; the remaining keyword arguments
        are passed to every attention module. Returns the final hidden states.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # The attention module is called with return_dict=False, so its
            # result is indexed with [0] to get the hidden states.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
604
+
605
+
606
class CrossAttnUpBlockMV2D(nn.Module):
    """UNet up block made of (resnet, multi-view attention) pairs.

    Builds ``num_layers`` pairs of ``ResnetBlock2D`` + ``TransformerMV2DModel``
    and, when ``add_upsample`` is true, one trailing ``Upsample2D``. Each
    resnet takes the running hidden states concatenated with one skip tensor.

    The multi-view arguments (``num_views``, ``joint_attention``,
    ``joint_attention_twice``, ``multiview_attention``,
    ``cross_domain_attention``) are forwarded unchanged to
    ``TransformerMV2DModel``; their meaning is defined there.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool = False,
    ):
        """Create the sub-modules.

        Raises:
            NotImplementedError: If ``dual_cross_attention`` is truthy and at
                least one layer would be built.
        """
        super().__init__()

        # Fail fast, before any sub-module is allocated. The condition matches
        # the original per-layer check, which only fired when num_layers > 0.
        if dual_cross_attention and num_layers > 0:
            raise NotImplementedError(
                "dual_cross_attention is not supported by CrossAttnUpBlockMV2D"
            )

        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            # The last layer consumes a skip tensor with `in_channels`
            # channels; earlier layers consume one with `out_channels`.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # The first layer receives the previous block's output.
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                TransformerMV2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    num_views=num_views,
                    joint_attention=joint_attention,
                    joint_attention_twice=joint_attention_twice,
                    multiview_attention=multiview_attention,
                    cross_domain_attention=cross_domain_attention,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        # Off by default. The original code branched on `num_views == 4` but
        # assigned False on both arms, so the branch is collapsed here.
        self.gradient_checkpointing = False

    @staticmethod
    def _create_custom_forward(module, return_dict=None):
        """Wrap ``module`` so it can be called with positional inputs only."""

        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Apply each (resnet, attention) pair, then the optional upsampler.

        For every pair, the last entry of ``res_hidden_states_tuple`` is
        removed and concatenated to ``hidden_states`` along ``dim=1`` before
        the resnet runs. Returns the final hidden states.
        """
        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:
            # Loop-invariant, so computed once rather than per layer.
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Arguments are positional here; the two None values fill the
                # slots labelled timestep and class_labels in the original.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
763
+
764
+
765
class CrossAttnDownBlockMV2D(nn.Module):
    """UNet down block made of (resnet, multi-view attention) pairs.

    Builds ``num_layers`` pairs of ``ResnetBlock2D`` + ``TransformerMV2DModel``
    and, when ``add_downsample`` is true, one trailing ``Downsample2D``.

    The multi-view arguments (``num_views``, ``joint_attention``,
    ``joint_attention_twice``, ``multiview_attention``,
    ``cross_domain_attention``) are forwarded unchanged to
    ``TransformerMV2DModel``; their meaning is defined there.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool = False,
    ):
        """Create the sub-modules.

        Raises:
            NotImplementedError: If ``dual_cross_attention`` is truthy and at
                least one layer would be built.
        """
        super().__init__()

        # Fail fast, before any sub-module is allocated. The condition matches
        # the original per-layer check, which only fired when num_layers > 0.
        if dual_cross_attention and num_layers > 0:
            raise NotImplementedError(
                "dual_cross_attention is not supported by CrossAttnDownBlockMV2D"
            )

        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            # Only the first resnet changes the channel count.
            layer_in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                TransformerMV2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    num_views=num_views,
                    joint_attention=joint_attention,
                    joint_attention_twice=joint_attention_twice,
                    multiview_attention=multiview_attention,
                    cross_domain_attention=cross_domain_attention,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        # Off by default. The original code branched on `num_views == 4` but
        # assigned False on both arms, so the branch is collapsed here.
        self.gradient_checkpointing = False

    @staticmethod
    def _create_custom_forward(module, return_dict=None):
        """Wrap ``module`` so it can be called with positional inputs only."""

        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        additional_residuals=None,
    ):
        """Apply each (resnet, attention) pair, then the optional downsampler.

        Returns:
            A ``(hidden_states, output_states)`` pair. ``output_states`` is a
            tuple holding the hidden states after every (resnet, attention)
            pair, plus the downsampled hidden states when a downsampler
            exists.
        """
        output_states = ()

        blocks = list(zip(self.resnets, self.attentions))

        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:
            # Loop-invariant, so computed once rather than per layer.
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for i, (resnet, attn) in enumerate(blocks):
            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Arguments are positional here; the two None values fill the
                # slots labelled timestep and class_labels in the original.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

            # apply additional residuals to the output of the last pair of resnet and attention blocks
            if i == len(blocks) - 1 and additional_residuals is not None:
                hidden_states = hidden_states + additional_residuals

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            # The downsampled tensor is recorded as one more output state.
            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
926
+
2D_Stage/tuneavideo/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+ from einops import rearrange
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.models.activations import get_activation
28
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
29
+ from diffusers.models.embeddings import (
30
+ GaussianFourierProjection,
31
+ ImageHintTimeEmbedding,
32
+ ImageProjection,
33
+ ImageTimeEmbedding,
34
+ TextImageProjection,
35
+ TextImageTimeEmbedding,
36
+ TextTimeEmbedding,
37
+ TimestepEmbedding,
38
+ Timesteps,
39
+ )
40
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
41
+ from diffusers.models.unet_2d_blocks import (
42
+ CrossAttnDownBlock2D,
43
+ CrossAttnUpBlock2D,
44
+ DownBlock2D,
45
+ UNetMidBlock2DCrossAttn,
46
+ UNetMidBlock2DSimpleCrossAttn,
47
+ UpBlock2D,
48
+ )
49
+ from diffusers.utils import (
50
+ CONFIG_NAME,
51
+ DIFFUSERS_CACHE,
52
+ FLAX_WEIGHTS_NAME,
53
+ HF_HUB_OFFLINE,
54
+ SAFETENSORS_WEIGHTS_NAME,
55
+ WEIGHTS_NAME,
56
+ _add_variant,
57
+ _get_model_file,
58
+ deprecate,
59
+ is_accelerate_available,
60
+ is_torch_version,
61
+ logging,
62
+ )
63
+ from diffusers import __version__
64
+ from tuneavideo.models.unet_mv2d_blocks import (
65
+ CrossAttnDownBlockMV2D,
66
+ CrossAttnUpBlockMV2D,
67
+ UNetMidBlockMV2DCrossAttn,
68
+ get_down_block,
69
+ get_up_block,
70
+ )
71
+ from diffusers.models.attention_processor import Attention, AttnProcessor
72
+ from diffusers.utils.import_utils import is_xformers_available
73
+ from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
74
+ from tuneavideo.models.refunet import ReferenceOnlyAttnProc
75
+
76
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
77
+
78
+
79
@dataclass
class UNetMV2DConditionOutput(BaseOutput):
    """
    Output container modelled on the output of [`UNet2DConditionModel`].

    NOTE(review): presumably this is what `UNetMV2DConditionModel` returns; its forward pass is not
    visible from this definition, so confirm against the caller.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states produced by the last layer of the model, conditioned on the
            `encoder_hidden_states` input.
    """

    # No tensor is attached by default.
    sample: torch.FloatTensor = None
90
+
91
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
92
+ r"""
93
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
94
+ shaped output.
95
+
96
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
97
+ for all models (such as downloading or saving).
98
+
99
+ Parameters:
100
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
101
+ Height and width of input/output sample.
102
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
103
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
104
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
105
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
106
+ Whether to flip the sin to cos in the time embedding.
107
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
108
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
109
+ The tuple of downsample blocks to use.
110
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
111
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
112
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
113
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
114
+ The tuple of upsample blocks to use.
115
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
116
+ Whether to include self-attention in the basic transformer blocks, see
117
+ [`~models.attention.BasicTransformerBlock`].
118
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
119
+ The tuple of output channels for each block.
120
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
121
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
122
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
123
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
124
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
125
+ If `None`, normalization and activation layers are skipped in post-processing.
126
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
127
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
128
+ The dimension of the cross attention features.
129
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
130
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
131
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
132
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
133
+ encoder_hid_dim (`int`, *optional*, defaults to None):
134
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
135
+ dimension to `cross_attention_dim`.
136
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
137
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
138
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
139
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
140
+ num_attention_heads (`int`, *optional*):
141
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
142
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
143
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
144
+ class_embed_type (`str`, *optional*, defaults to `None`):
145
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
146
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
147
+ addition_embed_type (`str`, *optional*, defaults to `None`):
148
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
149
+ "text". "text" will use the `TextTimeEmbedding` layer.
150
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
151
+ Dimension for the timestep embeddings.
152
+ num_class_embeds (`int`, *optional*, defaults to `None`):
153
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
154
+ class conditioning with `class_embed_type` equal to `None`.
155
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
156
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
157
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
158
+ An optional override for the dimension of the projected time embedding.
159
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
160
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
161
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
162
+ timestep_post_act (`str`, *optional*, defaults to `None`):
163
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
164
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
165
+ The dimension of `cond_proj` layer in the timestep embedding.
166
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
167
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
168
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
169
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
170
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
171
+ embeddings with the class embeddings.
172
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
173
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
174
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
175
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
176
+ otherwise.
177
+ """
178
+
179
+ _supports_gradient_checkpointing = True
180
+
181
+ @register_to_config
182
+ def __init__(
183
+ self,
184
+ sample_size: Optional[int] = None,
185
+ in_channels: int = 4,
186
+ out_channels: int = 4,
187
+ center_input_sample: bool = False,
188
+ flip_sin_to_cos: bool = True,
189
+ freq_shift: int = 0,
190
+ down_block_types: Tuple[str] = (
191
+ "CrossAttnDownBlockMV2D",
192
+ "CrossAttnDownBlockMV2D",
193
+ "CrossAttnDownBlockMV2D",
194
+ "DownBlock2D",
195
+ ),
196
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
197
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
198
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
199
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
200
+ layers_per_block: Union[int, Tuple[int]] = 2,
201
+ downsample_padding: int = 1,
202
+ mid_block_scale_factor: float = 1,
203
+ act_fn: str = "silu",
204
+ norm_num_groups: Optional[int] = 32,
205
+ norm_eps: float = 1e-5,
206
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
207
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
208
+ encoder_hid_dim: Optional[int] = None,
209
+ encoder_hid_dim_type: Optional[str] = None,
210
+ attention_head_dim: Union[int, Tuple[int]] = 8,
211
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
212
+ dual_cross_attention: bool = False,
213
+ use_linear_projection: bool = False,
214
+ class_embed_type: Optional[str] = None,
215
+ addition_embed_type: Optional[str] = None,
216
+ addition_time_embed_dim: Optional[int] = None,
217
+ num_class_embeds: Optional[int] = None,
218
+ upcast_attention: bool = False,
219
+ resnet_time_scale_shift: str = "default",
220
+ resnet_skip_time_act: bool = False,
221
+ resnet_out_scale_factor: int = 1.0,
222
+ time_embedding_type: str = "positional",
223
+ time_embedding_dim: Optional[int] = None,
224
+ time_embedding_act_fn: Optional[str] = None,
225
+ timestep_post_act: Optional[str] = None,
226
+ time_cond_proj_dim: Optional[int] = None,
227
+ conv_in_kernel: int = 3,
228
+ conv_out_kernel: int = 3,
229
+ projection_class_embeddings_input_dim: Optional[int] = None,
230
+ class_embeddings_concat: bool = False,
231
+ mid_block_only_cross_attention: Optional[bool] = None,
232
+ cross_attention_norm: Optional[str] = None,
233
+ addition_embed_type_num_heads=64,
234
+ num_views: int = 1,
235
+ joint_attention: bool = False,
236
+ joint_attention_twice: bool = False,
237
+ multiview_attention: bool = True,
238
+ cross_domain_attention: bool = False,
239
+ camera_input_dim: int = 12,
240
+ camera_hidden_dim: int = 320,
241
+ camera_output_dim: int = 1280,
242
+
243
+ ):
244
+ super().__init__()
245
+
246
+ self.sample_size = sample_size
247
+
248
+ if num_attention_heads is not None:
249
+ raise ValueError(
250
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
251
+ )
252
+
253
+ # If `num_attention_heads` is not defined (which is the case for most models)
254
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
255
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
256
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
257
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
258
+ # which is why we correct for the naming here.
259
+ num_attention_heads = num_attention_heads or attention_head_dim
260
+
261
+ # Check inputs
262
+ if len(down_block_types) != len(up_block_types):
263
+ raise ValueError(
264
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
265
+ )
266
+
267
+ if len(block_out_channels) != len(down_block_types):
268
+ raise ValueError(
269
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
270
+ )
271
+
272
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
273
+ raise ValueError(
274
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
278
+ raise ValueError(
279
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
280
+ )
281
+
282
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
283
+ raise ValueError(
284
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
285
+ )
286
+
287
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
288
+ raise ValueError(
289
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
290
+ )
291
+
292
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
293
+ raise ValueError(
294
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
295
+ )
296
+
297
+ # input
298
+ conv_in_padding = (conv_in_kernel - 1) // 2
299
+ self.conv_in = nn.Conv2d(
300
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
301
+ )
302
+
303
+ # time
304
+ if time_embedding_type == "fourier":
305
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
306
+ if time_embed_dim % 2 != 0:
307
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
308
+ self.time_proj = GaussianFourierProjection(
309
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
310
+ )
311
+ timestep_input_dim = time_embed_dim
312
+ elif time_embedding_type == "positional":
313
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
314
+
315
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
316
+ timestep_input_dim = block_out_channels[0]
317
+ else:
318
+ raise ValueError(
319
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
320
+ )
321
+
322
+ self.time_embedding = TimestepEmbedding(
323
+ timestep_input_dim,
324
+ time_embed_dim,
325
+ act_fn=act_fn,
326
+ post_act_fn=timestep_post_act,
327
+ cond_proj_dim=time_cond_proj_dim,
328
+ )
329
+
330
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
331
+ encoder_hid_dim_type = "text_proj"
332
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
333
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
334
+
335
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
336
+ raise ValueError(
337
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
338
+ )
339
+
340
+ if encoder_hid_dim_type == "text_proj":
341
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
342
+ elif encoder_hid_dim_type == "text_image_proj":
343
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
344
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
345
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
346
+ self.encoder_hid_proj = TextImageProjection(
347
+ text_embed_dim=encoder_hid_dim,
348
+ image_embed_dim=cross_attention_dim,
349
+ cross_attention_dim=cross_attention_dim,
350
+ )
351
+ elif encoder_hid_dim_type == "image_proj":
352
+ # Kandinsky 2.2
353
+ self.encoder_hid_proj = ImageProjection(
354
+ image_embed_dim=encoder_hid_dim,
355
+ cross_attention_dim=cross_attention_dim,
356
+ )
357
+ elif encoder_hid_dim_type is not None:
358
+ raise ValueError(
359
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
360
+ )
361
+ else:
362
+ self.encoder_hid_proj = None
363
+
364
+ # class embedding
365
+ if class_embed_type is None and num_class_embeds is not None:
366
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
367
+ elif class_embed_type == "timestep":
368
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
369
+ elif class_embed_type == "identity":
370
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
371
+ elif class_embed_type == "projection":
372
+ if projection_class_embeddings_input_dim is None:
373
+ raise ValueError(
374
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
375
+ )
376
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
377
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
378
+ # 2. it projects from an arbitrary input dimension.
379
+ #
380
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
381
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
382
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
383
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
384
+ elif class_embed_type == "simple_projection":
385
+ if projection_class_embeddings_input_dim is None:
386
+ raise ValueError(
387
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
388
+ )
389
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
390
+ else:
391
+ self.class_embedding = None
392
+
393
+ if addition_embed_type == "text":
394
+ if encoder_hid_dim is not None:
395
+ text_time_embedding_from_dim = encoder_hid_dim
396
+ else:
397
+ text_time_embedding_from_dim = cross_attention_dim
398
+
399
+ self.add_embedding = TextTimeEmbedding(
400
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
401
+ )
402
+ elif addition_embed_type == "text_image":
403
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
404
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
405
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
406
+ self.add_embedding = TextImageTimeEmbedding(
407
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
408
+ )
409
+ elif addition_embed_type == "text_time":
410
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
411
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
412
+ elif addition_embed_type == "image":
413
+ # Kandinsky 2.2
414
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
415
+ elif addition_embed_type == "image_hint":
416
+ # Kandinsky 2.2 ControlNet
417
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
418
+ elif addition_embed_type is not None:
419
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
420
+
421
+ if time_embedding_act_fn is None:
422
+ self.time_embed_act = None
423
+ else:
424
+ self.time_embed_act = get_activation(time_embedding_act_fn)
425
+
426
+ self.camera_embedding = nn.Sequential(
427
+ nn.Linear(camera_input_dim, time_embed_dim),
428
+ nn.SiLU(),
429
+ nn.Linear(time_embed_dim, time_embed_dim),
430
+ )
431
+
432
+ self.down_blocks = nn.ModuleList([])
433
+ self.up_blocks = nn.ModuleList([])
434
+
435
+ if isinstance(only_cross_attention, bool):
436
+ if mid_block_only_cross_attention is None:
437
+ mid_block_only_cross_attention = only_cross_attention
438
+
439
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
440
+
441
+ if mid_block_only_cross_attention is None:
442
+ mid_block_only_cross_attention = False
443
+
444
+ if isinstance(num_attention_heads, int):
445
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
446
+
447
+ if isinstance(attention_head_dim, int):
448
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
449
+
450
+ if isinstance(cross_attention_dim, int):
451
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
452
+
453
+ if isinstance(layers_per_block, int):
454
+ layers_per_block = [layers_per_block] * len(down_block_types)
455
+
456
+ if isinstance(transformer_layers_per_block, int):
457
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
458
+
459
+ if class_embeddings_concat:
460
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
461
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
462
+ # regular time embeddings
463
+ blocks_time_embed_dim = time_embed_dim * 2
464
+ else:
465
+ blocks_time_embed_dim = time_embed_dim
466
+
467
+ # down
468
+ output_channel = block_out_channels[0]
469
+ for i, down_block_type in enumerate(down_block_types):
470
+ input_channel = output_channel
471
+ output_channel = block_out_channels[i]
472
+ is_final_block = i == len(block_out_channels) - 1
473
+
474
+ down_block = get_down_block(
475
+ down_block_type,
476
+ num_layers=layers_per_block[i],
477
+ transformer_layers_per_block=transformer_layers_per_block[i],
478
+ in_channels=input_channel,
479
+ out_channels=output_channel,
480
+ temb_channels=blocks_time_embed_dim,
481
+ add_downsample=not is_final_block,
482
+ resnet_eps=norm_eps,
483
+ resnet_act_fn=act_fn,
484
+ resnet_groups=norm_num_groups,
485
+ cross_attention_dim=cross_attention_dim[i],
486
+ num_attention_heads=num_attention_heads[i],
487
+ downsample_padding=downsample_padding,
488
+ dual_cross_attention=dual_cross_attention,
489
+ use_linear_projection=use_linear_projection,
490
+ only_cross_attention=only_cross_attention[i],
491
+ upcast_attention=upcast_attention,
492
+ resnet_time_scale_shift=resnet_time_scale_shift,
493
+ resnet_skip_time_act=resnet_skip_time_act,
494
+ resnet_out_scale_factor=resnet_out_scale_factor,
495
+ cross_attention_norm=cross_attention_norm,
496
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
497
+ num_views=num_views,
498
+ joint_attention=joint_attention,
499
+ joint_attention_twice=joint_attention_twice,
500
+ multiview_attention=multiview_attention,
501
+ cross_domain_attention=cross_domain_attention
502
+ )
503
+ self.down_blocks.append(down_block)
504
+
505
+ # mid
506
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
507
+ self.mid_block = UNetMidBlock2DCrossAttn(
508
+ transformer_layers_per_block=transformer_layers_per_block[-1],
509
+ in_channels=block_out_channels[-1],
510
+ temb_channels=blocks_time_embed_dim,
511
+ resnet_eps=norm_eps,
512
+ resnet_act_fn=act_fn,
513
+ output_scale_factor=mid_block_scale_factor,
514
+ resnet_time_scale_shift=resnet_time_scale_shift,
515
+ cross_attention_dim=cross_attention_dim[-1],
516
+ num_attention_heads=num_attention_heads[-1],
517
+ resnet_groups=norm_num_groups,
518
+ dual_cross_attention=dual_cross_attention,
519
+ use_linear_projection=use_linear_projection,
520
+ upcast_attention=upcast_attention,
521
+ )
522
+ # custom MV2D attention block
523
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
524
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
525
+ transformer_layers_per_block=transformer_layers_per_block[-1],
526
+ in_channels=block_out_channels[-1],
527
+ temb_channels=blocks_time_embed_dim,
528
+ resnet_eps=norm_eps,
529
+ resnet_act_fn=act_fn,
530
+ output_scale_factor=mid_block_scale_factor,
531
+ resnet_time_scale_shift=resnet_time_scale_shift,
532
+ cross_attention_dim=cross_attention_dim[-1],
533
+ num_attention_heads=num_attention_heads[-1],
534
+ resnet_groups=norm_num_groups,
535
+ dual_cross_attention=dual_cross_attention,
536
+ use_linear_projection=use_linear_projection,
537
+ upcast_attention=upcast_attention,
538
+ num_views=num_views,
539
+ joint_attention=joint_attention,
540
+ joint_attention_twice=joint_attention_twice,
541
+ multiview_attention=multiview_attention,
542
+ cross_domain_attention=cross_domain_attention
543
+ )
544
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
545
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
546
+ in_channels=block_out_channels[-1],
547
+ temb_channels=blocks_time_embed_dim,
548
+ resnet_eps=norm_eps,
549
+ resnet_act_fn=act_fn,
550
+ output_scale_factor=mid_block_scale_factor,
551
+ cross_attention_dim=cross_attention_dim[-1],
552
+ attention_head_dim=attention_head_dim[-1],
553
+ resnet_groups=norm_num_groups,
554
+ resnet_time_scale_shift=resnet_time_scale_shift,
555
+ skip_time_act=resnet_skip_time_act,
556
+ only_cross_attention=mid_block_only_cross_attention,
557
+ cross_attention_norm=cross_attention_norm,
558
+ )
559
+ elif mid_block_type is None:
560
+ self.mid_block = None
561
+ else:
562
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
563
+
564
+ # count how many layers upsample the images
565
+ self.num_upsamplers = 0
566
+
567
+ # up
568
+ reversed_block_out_channels = list(reversed(block_out_channels))
569
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
570
+ reversed_layers_per_block = list(reversed(layers_per_block))
571
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
572
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
573
+ only_cross_attention = list(reversed(only_cross_attention))
574
+
575
+ output_channel = reversed_block_out_channels[0]
576
+ for i, up_block_type in enumerate(up_block_types):
577
+ is_final_block = i == len(block_out_channels) - 1
578
+
579
+ prev_output_channel = output_channel
580
+ output_channel = reversed_block_out_channels[i]
581
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
582
+
583
+ # add upsample block for all BUT final layer
584
+ if not is_final_block:
585
+ add_upsample = True
586
+ self.num_upsamplers += 1
587
+ else:
588
+ add_upsample = False
589
+
590
+ up_block = get_up_block(
591
+ up_block_type,
592
+ num_layers=reversed_layers_per_block[i] + 1,
593
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
594
+ in_channels=input_channel,
595
+ out_channels=output_channel,
596
+ prev_output_channel=prev_output_channel,
597
+ temb_channels=blocks_time_embed_dim,
598
+ add_upsample=add_upsample,
599
+ resnet_eps=norm_eps,
600
+ resnet_act_fn=act_fn,
601
+ resnet_groups=norm_num_groups,
602
+ cross_attention_dim=reversed_cross_attention_dim[i],
603
+ num_attention_heads=reversed_num_attention_heads[i],
604
+ dual_cross_attention=dual_cross_attention,
605
+ use_linear_projection=use_linear_projection,
606
+ only_cross_attention=only_cross_attention[i],
607
+ upcast_attention=upcast_attention,
608
+ resnet_time_scale_shift=resnet_time_scale_shift,
609
+ resnet_skip_time_act=resnet_skip_time_act,
610
+ resnet_out_scale_factor=resnet_out_scale_factor,
611
+ cross_attention_norm=cross_attention_norm,
612
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
613
+ num_views=num_views,
614
+ joint_attention=joint_attention,
615
+ joint_attention_twice=joint_attention_twice,
616
+ multiview_attention=multiview_attention,
617
+ cross_domain_attention=cross_domain_attention
618
+ )
619
+ self.up_blocks.append(up_block)
620
+ prev_output_channel = output_channel
621
+
622
+ # out
623
+ if norm_num_groups is not None:
624
+ self.conv_norm_out = nn.GroupNorm(
625
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
626
+ )
627
+
628
+ self.conv_act = get_activation(act_fn)
629
+
630
+ else:
631
+ self.conv_norm_out = None
632
+ self.conv_act = None
633
+
634
+ conv_out_padding = (conv_out_kernel - 1) // 2
635
+ self.conv_out = nn.Conv2d(
636
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
637
+ )
638
+
639
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect the attention processor of every layer that exposes `set_processor`.

    Returns:
        `dict`: Maps `"<dotted.module.path>.processor"` to the processor object currently attached to that
        module. Entries appear in depth-first, parent-before-child order of `named_children()`.
    """
    collected = {}

    # Explicit stack instead of recursion. Children are pushed in reverse so that they are
    # popped (and therefore visited) in their registration order.
    pending = list(reversed(list(self.named_children())))
    while pending:
        path, layer = pending.pop()

        if hasattr(layer, "set_processor"):
            collected[f"{path}.processor"] = layer.processor

        pending.extend(
            (f"{path}.{child_name}", child)
            for child_name, child in reversed(list(layer.named_children()))
        )

    return collected
662
+
663
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single processor instance, which is installed on **every** module exposing
            `set_processor`, or a dictionary with one processor per attention layer.

            If `processor` is a dict, each key must be the module path followed by `.processor` (the same
            keys that `attn_processors` returns). The dictionary is only read; it is not modified.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
        KeyError: If a dict is passed that lacks the key for one of the attention layers.
    """
    count = len(self.attn_processors.keys())
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module):
        if hasattr(module, "set_processor"):
            # Look the entry up instead of popping it, so the caller's dict is left intact
            # and can be reused or inspected after this call.
            module.set_processor(processor[f"{name}.processor"] if per_layer else processor)

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module)
696
+
697
def set_default_attn_processor(self):
    """
    Replace any custom attention processors with a plain `AttnProcessor`.

    A single `AttnProcessor` instance is created and handed to `set_attn_processor`, which installs it on
    every module that exposes `set_processor`.
    """
    default_processor = AttnProcessor()
    self.set_attn_processor(default_processor)
702
+
703
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    With slicing enabled, attention layers split their input and compute attention in several steps, which
    saves memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, each layer gets half of its `sliceable_head_dim`, so attention is computed in two
            steps. If `"max"`, every layer gets a slice size of 1. A single number is applied to every
            layer. A list supplies one value per sliceable layer, in depth-first module order.

    Raises:
        ValueError: If a list is given whose length differs from the number of sliceable layers, or if any
            requested size exceeds the `sliceable_head_dim` of its layer.
    """

    def _sliceable_layers(root: torch.nn.Module):
        # Depth-first walk below `root`; a module is yielded before its own children are visited.
        for child in root.children():
            if hasattr(child, "set_attention_slice"):
                yield child
            yield from _sliceable_layers(child)

    head_dims = [layer.sliceable_head_dim for layer in _sliceable_layers(self)]
    layer_count = len(head_dims)

    # Normalise every accepted spelling of `slice_size` to one entry per sliceable layer.
    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between speed and memory
        per_layer_sizes = [dim // 2 for dim in head_dims]
    elif slice_size == "max":
        # smallest slice possible
        per_layer_sizes = [1] * layer_count
    elif isinstance(slice_size, list):
        per_layer_sizes = slice_size
    else:
        per_layer_sizes = [slice_size] * layer_count

    if len(per_layer_sizes) != len(head_dims):
        raise ValueError(
            f"You have provided {len(per_layer_sizes)}, but {self.config} has {len(head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(head_dims)}."
        )

    # Validate everything before touching any layer.
    for size, dim in zip(per_layer_sizes, head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Second walk: hand each sliceable layer its size, in the same order the dims were collected.
    for layer, size in zip(_sliceable_layers(self), per_layer_sizes):
        layer.set_attention_slice(size)
767
+
768
+ # def _set_gradient_checkpointing(self, module, value=False):
769
+ # if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
770
+ # module.gradient_checkpointing = value
771
+
772
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    camera_matrixs: Optional[torch.Tensor] = None,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNetMV2DConditionOutput, Tuple]:
    r"""
    Run the multi-view conditional UNet on a batch of noisy samples.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with shape `(batch, channel, views, height, width)`. The view axis is
            folded into the batch axis (`(batch * views, channel, height, width)`) before `conv_in`.
        timestep (`torch.FloatTensor` or `float` or `int`):
            The current denoising timestep. Scalars are converted to a tensor and broadcast over the batch.
        encoder_hidden_states (`torch.FloatTensor`):
            Conditioning features handed to the cross-attention blocks (optionally projected by
            `encoder_hid_proj` first).
        camera_matrixs (`torch.Tensor`, *optional*):
            Per-view camera input for `self.camera_embedding`. When given, the time embedding is repeated
            once per view, the camera embedding is added to it, and the result is flattened to
            `(batch * views, dim)`. NOTE(review): when omitted the time embedding keeps batch size `batch`
            while `sample` becomes `batch * views` — confirm callers always pass it for multi-view input.
        class_labels (`torch.Tensor`, *optional*):
            Class conditioning. Only used when the model has a class embedding; when omitted the class
            embedding is skipped.
        timestep_cond (`torch.Tensor`, *optional*):
            Extra conditioning forwarded to `self.time_embedding`.
        attention_mask (`torch.Tensor`, *optional*):
            Mask of shape `(batch, key_tokens)` with 1 = keep, 0 = discard. It is converted to an additive
            bias (0 / -10000.0) with a singleton query dimension.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the attention blocks.
        added_cond_kwargs (`dict`, *optional*):
            A kwargs dictionary containing additional embeddings (`image_embeds`, `text_embeds`, `time_ids`,
            `hint`) required by the configured `addition_embed_type` / `encoder_hid_dim_type`.
        down_block_additional_residuals (sequence of `torch.Tensor`, *optional*):
            When passed together with `mid_block_additional_residual`, the tensors are added element-wise to
            the down-block skip connections. When passed alone, one tensor is consumed per down block. The
            sequence itself is not modified.
        mid_block_additional_residual (`torch.Tensor`, *optional*):
            Added to the mid-block output when `down_block_additional_residuals` is also given.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            A cross-attention mask of shape `(batch, sequence_length)` applied to `encoder_hidden_states`.
            It is converted into a bias the same way as `attention_mask`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `UNetMV2DConditionOutput` instead of a plain tuple.

    Returns:
        `UNetMV2DConditionOutput` or `tuple`:
            If `return_dict` is True, a `UNetMV2DConditionOutput` is returned, otherwise a `tuple` whose
            first element is the sample tensor.

    Raises:
        ValueError: If the configured embedding type needs an entry that is missing from
            `added_cond_kwargs`.
    """
    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch,                    1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None:
        # assume that mask is expressed as:
        #   (1 = keep,      0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #       (keep = +0,     discard = -10000.0)
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # Treat a missing kwargs dict as empty so that the membership checks below raise the
    # descriptive ValueError instead of a TypeError on `None`.
    if added_cond_kwargs is None:
        added_cond_kwargs = {}

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    if camera_matrixs is not None:
        # Give every view its own copy of the time embedding, add the per-view camera
        # embedding, then fold views into the batch axis to match the flattened sample.
        emb = torch.unsqueeze(emb, 1)
        cam_emb = self.camera_embedding(camera_matrixs)
        emb = emb.repeat(1, cam_emb.shape[1], 1)
        emb = emb + cam_emb
        emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])

    aug_emb = None

    # Class conditioning is optional here: it is applied only when labels are actually supplied.
    if self.class_embedding is not None and class_labels is not None:
        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

            # `Timesteps` does not contain any weights and will always return f32 tensors
            # there might be better ways to encapsulate this.
            class_labels = class_labels.to(dtype=sample.dtype)

        class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    if self.config.addition_embed_type == "text":
        aug_emb = self.add_embedding(encoder_hidden_states)
    elif self.config.addition_embed_type == "text_image":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )

        image_embs = added_cond_kwargs.get("image_embeds")
        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
        aug_emb = self.add_embedding(text_embs, image_embs)
    elif self.config.addition_embed_type == "text_time":
        # SDXL - style
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)
    elif self.config.addition_embed_type == "image":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        image_embs = added_cond_kwargs.get("image_embeds")
        aug_emb = self.add_embedding(image_embs)
    elif self.config.addition_embed_type == "image_hint":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
            )
        image_embs = added_cond_kwargs.get("image_embeds")
        hint = added_cond_kwargs.get("hint")
        aug_emb, hint = self.add_embedding(image_embs, hint)
        sample = torch.cat([sample, hint], dim=1)

    emb = emb + aug_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
            )

        image_embeds = added_cond_kwargs.get("image_embeds")
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
            )
        image_embeds = added_cond_kwargs.get("image_embeds")
        encoder_hidden_states = self.encoder_hid_proj(image_embeds)

    # 2. pre-process: fold the view axis into the batch axis so 2D blocks can be applied
    sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
    sample = self.conv_in(sample)

    # 3. down
    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

    # Adapter residuals are consumed one per block. Work on a private list so that tuples are
    # accepted (as the annotation promises) and the caller's sequence is left untouched.
    adapter_residuals = list(down_block_additional_residuals) if is_adapter else []

    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # For t2i-adapter CrossAttnDownBlock2D
            additional_residuals = {}
            if adapter_residuals:
                additional_residuals["additional_residuals"] = adapter_residuals.pop(0)

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            if adapter_residuals:
                sample += adapter_residuals.pop(0)

        down_block_res_samples += res_samples

    if is_controlnet:
        # Add the extra residuals onto the skip connections, pairwise.
        down_block_res_samples = tuple(
            down_block_res_sample + down_block_additional_residual
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            )
        )

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
        )

    if is_controlnet:
        sample = sample + mid_block_additional_residual

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        # Each up block consumes as many skip connections as it has resnets, newest first.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    if self.conv_norm_out is not None:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNetMV2DConditionOutput(sample=sample)
1080
+
1081
+ @classmethod
1082
+ def from_pretrained_2d(
1083
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1084
+ camera_embedding_type: str, num_views: int, sample_size: int,
1085
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1086
+ projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
1087
+ joint_attention_twice: bool = False, multiview_attention: bool = True,
1088
+ cross_domain_attention: bool = False,
1089
+ in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
1090
+ **kwargs
1091
+ ):
1092
+ r"""
1093
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1094
+
1095
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1096
+ train the model, set it back in training mode with `model.train()`.
1097
+
1098
+ Parameters:
1099
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1100
+ Can be either:
1101
+
1102
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1103
+ the Hub.
1104
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1105
+ with [`~ModelMixin.save_pretrained`].
1106
+
1107
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1108
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1109
+ is not used.
1110
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1111
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1112
+ dtype is automatically derived from the model's weights.
1113
+ force_download (`bool`, *optional*, defaults to `False`):
1114
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1115
+ cached versions if they exist.
1116
+ resume_download (`bool`, *optional*, defaults to `False`):
1117
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1118
+ incompletely downloaded files are deleted.
1119
+ proxies (`Dict[str, str]`, *optional*):
1120
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1121
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1122
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1123
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1124
+ local_files_only(`bool`, *optional*, defaults to `False`):
1125
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1126
+ won't be downloaded from the Hub.
1127
+ use_auth_token (`str` or *bool*, *optional*):
1128
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1129
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1130
+ revision (`str`, *optional*, defaults to `"main"`):
1131
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1132
+ allowed by Git.
1133
+ from_flax (`bool`, *optional*, defaults to `False`):
1134
+ Load the model weights from a Flax checkpoint save file.
1135
+ subfolder (`str`, *optional*, defaults to `""`):
1136
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1137
+ mirror (`str`, *optional*):
1138
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1139
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1140
+ information.
1141
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1142
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1143
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1144
+ same device.
1145
+
1146
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1147
+ more information about each option see [designing a device
1148
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1149
+ max_memory (`Dict`, *optional*):
1150
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1151
+ each GPU and the available CPU RAM if unset.
1152
+ offload_folder (`str` or `os.PathLike`, *optional*):
1153
+ The path to offload weights if `device_map` contains the value `"disk"`.
1154
+ offload_state_dict (`bool`, *optional*):
1155
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1156
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1157
+ when there is some disk offload.
1158
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1159
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
1160
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1161
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1162
+ argument to `True` will raise an error.
1163
+ variant (`str`, *optional*):
1164
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1165
+ loading `from_flax`.
1166
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1167
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1168
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1169
+ weights. If set to `False`, `safetensors` weights are not loaded.
1170
+
1171
+ <Tip>
1172
+
1173
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1174
+ `huggingface-cli login`. You can also activate the special
1175
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1176
+ firewalled environment.
1177
+
1178
+ </Tip>
1179
+
1180
+ Example:
1181
+
1182
+ ```py
1183
+ from diffusers import UNet2DConditionModel
1184
+
1185
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1186
+ ```
1187
+
1188
+ If you get the error message below, you need to finetune the weights for your downstream task:
1189
+
1190
+ ```bash
1191
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1192
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1193
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1194
+ ```
1195
+ """
1196
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1197
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1198
+ force_download = kwargs.pop("force_download", False)
1199
+ from_flax = kwargs.pop("from_flax", False)
1200
+ resume_download = kwargs.pop("resume_download", False)
1201
+ proxies = kwargs.pop("proxies", None)
1202
+ output_loading_info = kwargs.pop("output_loading_info", False)
1203
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1204
+ use_auth_token = kwargs.pop("use_auth_token", None)
1205
+ revision = kwargs.pop("revision", None)
1206
+ torch_dtype = kwargs.pop("torch_dtype", None)
1207
+ subfolder = kwargs.pop("subfolder", None)
1208
+ device_map = kwargs.pop("device_map", None)
1209
+ max_memory = kwargs.pop("max_memory", None)
1210
+ offload_folder = kwargs.pop("offload_folder", None)
1211
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1212
+ variant = kwargs.pop("variant", None)
1213
+ use_safetensors = kwargs.pop("use_safetensors", None)
1214
+
1215
+ # if use_safetensors and not is_safetensors_available():
1216
+ # raise ValueError(
1217
+ # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1218
+ # )
1219
+
1220
+ allow_pickle = False
1221
+ if use_safetensors is None:
1222
+ # use_safetensors = is_safetensors_available()
1223
+ use_safetensors = False
1224
+ allow_pickle = True
1225
+
1226
+ if device_map is not None and not is_accelerate_available():
1227
+ raise NotImplementedError(
1228
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1229
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1230
+ )
1231
+
1232
+ # Check if we can handle device_map and dispatching the weights
1233
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1234
+ raise NotImplementedError(
1235
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1236
+ " `device_map=None`."
1237
+ )
1238
+
1239
+ # Load config if we don't provide a configuration
1240
+ config_path = pretrained_model_name_or_path
1241
+
1242
+ user_agent = {
1243
+ "diffusers": __version__,
1244
+ "file_type": "model",
1245
+ "framework": "pytorch",
1246
+ }
1247
+
1248
+ # load config
1249
+ config, unused_kwargs, commit_hash = cls.load_config(
1250
+ config_path,
1251
+ cache_dir=cache_dir,
1252
+ return_unused_kwargs=True,
1253
+ return_commit_hash=True,
1254
+ force_download=force_download,
1255
+ resume_download=resume_download,
1256
+ proxies=proxies,
1257
+ local_files_only=local_files_only,
1258
+ use_auth_token=use_auth_token,
1259
+ revision=revision,
1260
+ subfolder=subfolder,
1261
+ device_map=device_map,
1262
+ max_memory=max_memory,
1263
+ offload_folder=offload_folder,
1264
+ offload_state_dict=offload_state_dict,
1265
+ user_agent=user_agent,
1266
+ **kwargs,
1267
+ )
1268
+
1269
+ # modify config
1270
+ config["_class_name"] = cls.__name__
1271
+ config['in_channels'] = in_channels
1272
+ config['out_channels'] = out_channels
1273
+ config['sample_size'] = sample_size # training resolution
1274
+ config['num_views'] = num_views
1275
+ config['joint_attention'] = joint_attention
1276
+ config['joint_attention_twice'] = joint_attention_twice
1277
+ config['multiview_attention'] = multiview_attention
1278
+ config['cross_domain_attention'] = cross_domain_attention
1279
+ config["down_block_types"] = [
1280
+ "CrossAttnDownBlockMV2D",
1281
+ "CrossAttnDownBlockMV2D",
1282
+ "CrossAttnDownBlockMV2D",
1283
+ "DownBlock2D"
1284
+ ]
1285
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1286
+ config["up_block_types"] = [
1287
+ "UpBlock2D",
1288
+ "CrossAttnUpBlockMV2D",
1289
+ "CrossAttnUpBlockMV2D",
1290
+ "CrossAttnUpBlockMV2D"
1291
+ ]
1292
+ config['class_embed_type'] = 'projection'
1293
+ if camera_embedding_type == 'e_de_da_sincos':
1294
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1295
+ else:
1296
+ raise NotImplementedError
1297
+
1298
+ # load model
1299
+ model_file = None
1300
+ if from_flax:
1301
+ raise NotImplementedError
1302
+ else:
1303
+ if use_safetensors:
1304
+ try:
1305
+ model_file = _get_model_file(
1306
+ pretrained_model_name_or_path,
1307
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1308
+ cache_dir=cache_dir,
1309
+ force_download=force_download,
1310
+ resume_download=resume_download,
1311
+ proxies=proxies,
1312
+ local_files_only=local_files_only,
1313
+ use_auth_token=use_auth_token,
1314
+ revision=revision,
1315
+ subfolder=subfolder,
1316
+ user_agent=user_agent,
1317
+ commit_hash=commit_hash,
1318
+ )
1319
+ except IOError as e:
1320
+ if not allow_pickle:
1321
+ raise e
1322
+ pass
1323
+ if model_file is None:
1324
+ model_file = _get_model_file(
1325
+ pretrained_model_name_or_path,
1326
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1327
+ cache_dir=cache_dir,
1328
+ force_download=force_download,
1329
+ resume_download=resume_download,
1330
+ proxies=proxies,
1331
+ local_files_only=local_files_only,
1332
+ use_auth_token=use_auth_token,
1333
+ revision=revision,
1334
+ subfolder=subfolder,
1335
+ user_agent=user_agent,
1336
+ commit_hash=commit_hash,
1337
+ )
1338
+
1339
+ model = cls.from_config(config, **unused_kwargs)
1340
+ if local_crossattn:
1341
+ unet_lora_attn_procs = dict()
1342
+ for name, _ in model.attn_processors.items():
1343
+ if not name.endswith("attn1.processor"):
1344
+ default_attn_proc = AttnProcessor()
1345
+ elif is_xformers_available():
1346
+ default_attn_proc = XFormersMVAttnProcessor()
1347
+ else:
1348
+ default_attn_proc = MVAttnProcessor()
1349
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
1350
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
1351
+ )
1352
+ model.set_attn_processor(unet_lora_attn_procs)
1353
+ state_dict = load_state_dict(model_file, variant=variant)
1354
+ model._convert_deprecated_attention_blocks(state_dict)
1355
+
1356
+ conv_in_weight = state_dict['conv_in.weight']
1357
+ conv_out_weight = state_dict['conv_out.weight']
1358
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1359
+ model,
1360
+ state_dict,
1361
+ model_file,
1362
+ pretrained_model_name_or_path,
1363
+ ignore_mismatched_sizes=True,
1364
+ )
1365
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1366
+ # initialize from the original SD structure
1367
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1368
+
1369
+ # whether to place all zero to new layers?
1370
+ if zero_init_conv_in:
1371
+ model.conv_in.weight.data[:,4:] = 0.
1372
+
1373
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1374
+ # initialize from the original SD structure
1375
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1376
+ if out_channels == 8: # copy for the last 4 channels
1377
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1378
+
1379
+ if zero_init_camera_projection:
1380
+ for p in model.class_embedding.parameters():
1381
+ torch.nn.init.zeros_(p)
1382
+
1383
+ loading_info = {
1384
+ "missing_keys": missing_keys,
1385
+ "unexpected_keys": unexpected_keys,
1386
+ "mismatched_keys": mismatched_keys,
1387
+ "error_msgs": error_msgs,
1388
+ }
1389
+
1390
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1391
+ raise ValueError(
1392
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1393
+ )
1394
+ elif torch_dtype is not None:
1395
+ model = model.to(torch_dtype)
1396
+
1397
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1398
+
1399
+ # Set model in evaluation mode to deactivate DropOut modules by default
1400
+ model.eval()
1401
+ if output_loading_info:
1402
+ return model, loading_info
1403
+
1404
+ return model
1405
+
1406
@classmethod
def _load_pretrained_model_2d(
    cls,
    model,
    state_dict,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
):
    """
    Copy a checkpoint `state_dict` into `model` and report how the keys lined up.

    Args:
        model: module that receives the weights; its own `state_dict()` defines
            the expected keys and shapes.
        state_dict: checkpoint mapping. When `ignore_mismatched_sizes` is true,
            entries whose shape differs from the model's are deleted from this
            mapping in place before loading.
        resolved_archive_file: accepted but not used by this method.
        pretrained_model_name_or_path: only interpolated into log messages.
        ignore_mismatched_sizes: when false, no shape comparison is done and the
            returned `mismatched_keys` list is empty.

    Returns:
        The tuple `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`,
        where each `mismatched_keys` entry is `(key, checkpoint_shape, model_shape)`.

    Raises:
        RuntimeError: if `_load_state_dict_into_model` returns any error message.
    """
    target_state = model.state_dict()
    checkpoint_keys = list(state_dict.keys())
    model_keys = list(target_state.keys())

    # Keys present on only one side of the checkpoint/model pair.
    missing_keys = list(set(model_keys) - set(checkpoint_keys))
    unexpected_keys = list(set(checkpoint_keys) - set(model_keys))

    if state_dict is not None:
        # Whole checkpoint
        mismatched_keys = []
        if ignore_mismatched_sizes:
            # Walk the key snapshot taken above so deleting from `state_dict`
            # during the loop is safe.
            for key in checkpoint_keys:
                if key not in target_state:
                    continue
                found_shape = state_dict[key].shape
                wanted_shape = target_state[key].shape
                if found_shape != wanted_shape:
                    mismatched_keys.append((key, found_shape, wanted_shape))
                    del state_dict[key]
        error_msgs = _load_state_dict_into_model(model, state_dict)

    model_name = model.__class__.__name__

    if error_msgs:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model_name}:\n\t{error_msg}")

    if unexpected_keys:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model_name}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model_name} from the checkpoint of a model trained on another task"
            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model_name} from the checkpoint of a model that you expect to be exactly"
            " identical (initializing a BertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model_name}.\n")

    if missing_keys:
        logger.warning(
            f"Some weights of {model_name} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif not mismatched_keys:
        logger.info(
            f"All the weights of {model_name} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
            f" checkpoint was trained on, you can already use {model_name} for predictions"
            " without further training."
        )

    if mismatched_keys:
        mismatched_warning = "\n".join(
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        )
        logger.warning(
            f"Some weights of {model_name} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
            " able to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1509
+
2D_Stage/tuneavideo/models/unet_mv2d_ref.py ADDED
@@ -0,0 +1,1570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+ from einops import rearrange
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.models.activations import get_activation
28
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
29
+ from diffusers.models.embeddings import (
30
+ GaussianFourierProjection,
31
+ ImageHintTimeEmbedding,
32
+ ImageProjection,
33
+ ImageTimeEmbedding,
34
+ TextImageProjection,
35
+ TextImageTimeEmbedding,
36
+ TextTimeEmbedding,
37
+ TimestepEmbedding,
38
+ Timesteps,
39
+ )
40
+ from diffusers.models.lora import LoRALinearLayer
41
+
42
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
43
+ from diffusers.models.unet_2d_blocks import (
44
+ CrossAttnDownBlock2D,
45
+ CrossAttnUpBlock2D,
46
+ DownBlock2D,
47
+ UNetMidBlock2DCrossAttn,
48
+ UNetMidBlock2DSimpleCrossAttn,
49
+ UpBlock2D,
50
+ )
51
+ from diffusers.utils import (
52
+ CONFIG_NAME,
53
+ DIFFUSERS_CACHE,
54
+ FLAX_WEIGHTS_NAME,
55
+ HF_HUB_OFFLINE,
56
+ SAFETENSORS_WEIGHTS_NAME,
57
+ WEIGHTS_NAME,
58
+ _add_variant,
59
+ _get_model_file,
60
+ deprecate,
61
+ is_accelerate_available,
62
+ is_torch_version,
63
+ logging,
64
+ )
65
+ from diffusers import __version__
66
+ from tuneavideo.models.unet_mv2d_blocks import (
67
+ CrossAttnDownBlockMV2D,
68
+ CrossAttnUpBlockMV2D,
69
+ UNetMidBlockMV2DCrossAttn,
70
+ get_down_block,
71
+ get_up_block,
72
+ )
73
+ from diffusers.models.attention_processor import Attention, AttnProcessor
74
+ from diffusers.utils.import_utils import is_xformers_available
75
+ from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
76
+ from tuneavideo.models.refunet import ReferenceOnlyAttnProc
77
+
78
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
79
+
80
+
81
@dataclass
class UNetMV2DRefOutput(BaseOutput):
    """
    Output container named after `UNetMV2DRefModel` (presumably what that model's
    forward pass returns; confirm against its `forward`).

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states from the last layer of the model, conditioned on the
            `encoder_hidden_states` input.
    """

    # No sample is attached by default.
    sample: torch.FloatTensor = None
92
+
93
class Identity(torch.nn.Module):
    r"""Pass-through module that ignores every constructor and call argument.

    Args:
        scale: accepted and discarded
        args: any positional argument (discarded)
        kwargs: any keyword argument (discarded)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])

    """

    def __init__(self, scale=None, *args, **kwargs) -> None:
        # Nothing is stored; the constructor arguments exist only so this class
        # can be built with whatever the replaced layer would have received.
        super().__init__()

    def forward(self, input, *args, **kwargs):
        # Hand back the very object that came in; extra call arguments are dropped.
        return input
118
+
119
+
120
+
121
class _LoRACompatibleLinear(nn.Module):
    """
    Pass-through stand-in for a LoRA-compatible linear layer.

    The module owns no weights: `forward` returns `hidden_states` untouched. It
    keeps the `lora_layer` attribute and the `set_lora_layer` / `_fuse_lora` /
    `_unfuse_lora` methods so callers written against a LoRA-aware linear layer
    can still invoke them.

    Args:
        *args: Linear-style positional arguments; accepted and ignored.
        lora_layer: optional LoRA layer stored on the instance; it is never
            applied in `forward`.
        **kwargs: Linear-style keyword arguments; accepted and ignored.
    """

    def __init__(self, *args, lora_layer: Optional["LoRALinearLayer"] = None, **kwargs):
        # Do not forward *args/**kwargs: `nn.Module.__init__` raises TypeError on
        # extra arguments, which made the class impossible to construct with the
        # arguments a real linear layer takes. They are dropped instead, the same
        # way the sibling `Identity` placeholder treats its arguments.
        super().__init__()
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional["LoRALinearLayer"]):
        """Replace the stored LoRA layer (or clear it with `None`)."""
        self.lora_layer = lora_layer

    def _fuse_lora(self):
        """No-op: there is no weight here to fuse a LoRA delta into."""
        pass

    def _unfuse_lora(self):
        """No-op counterpart of `_fuse_lora`."""
        pass

    def forward(self, hidden_states, scale=None, lora_scale: int = 1):
        # `scale` and `lora_scale` are accepted for call compatibility only.
        return hidden_states
141
+
142
+ class UNetMV2DRefModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
143
+ r"""
144
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
145
+ shaped output.
146
+
147
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
148
+ for all models (such as downloading or saving).
149
+
150
+ Parameters:
151
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
152
+ Height and width of input/output sample.
153
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
154
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
155
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
156
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
157
+ Whether to flip the sin to cos in the time embedding.
158
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
159
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
160
+ The tuple of downsample blocks to use.
161
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
162
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
163
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
164
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
165
+ The tuple of upsample blocks to use.
166
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
167
+ Whether to include self-attention in the basic transformer blocks, see
168
+ [`~models.attention.BasicTransformerBlock`].
169
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
170
+ The tuple of output channels for each block.
171
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
172
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
173
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
174
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
175
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
176
+ If `None`, normalization and activation layers is skipped in post-processing.
177
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
178
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
179
+ The dimension of the cross attention features.
180
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
181
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
182
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
183
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
184
+ encoder_hid_dim (`int`, *optional*, defaults to None):
185
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
186
+ dimension to `cross_attention_dim`.
187
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
188
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
189
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
190
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
191
+ num_attention_heads (`int`, *optional*):
192
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
193
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
194
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
195
+ class_embed_type (`str`, *optional*, defaults to `None`):
196
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
197
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
198
+ addition_embed_type (`str`, *optional*, defaults to `None`):
199
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
200
+ "text". "text" will use the `TextTimeEmbedding` layer.
201
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
202
+ Dimension for the timestep embeddings.
203
+ num_class_embeds (`int`, *optional*, defaults to `None`):
204
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
205
+ class conditioning with `class_embed_type` equal to `None`.
206
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
207
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
208
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
209
+ An optional override for the dimension of the projected time embedding.
210
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
211
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
212
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
213
+ timestep_post_act (`str`, *optional*, defaults to `None`):
214
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
215
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
216
+ The dimension of `cond_proj` layer in the timestep embedding.
217
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
218
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
219
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
220
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
221
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
222
+ embeddings with the class embeddings.
223
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
224
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
225
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
226
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
227
+ otherwise.
228
+ """
229
+
230
+ _supports_gradient_checkpointing = True
231
+
232
+ @register_to_config
233
+ def __init__(
234
+ self,
235
+ sample_size: Optional[int] = None,
236
+ in_channels: int = 4,
237
+ out_channels: int = 4,
238
+ center_input_sample: bool = False,
239
+ flip_sin_to_cos: bool = True,
240
+ freq_shift: int = 0,
241
+ down_block_types: Tuple[str] = (
242
+ "CrossAttnDownBlockMV2D",
243
+ "CrossAttnDownBlockMV2D",
244
+ "CrossAttnDownBlockMV2D",
245
+ "DownBlock2D",
246
+ ),
247
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
248
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
249
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
250
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
251
+ layers_per_block: Union[int, Tuple[int]] = 2,
252
+ downsample_padding: int = 1,
253
+ mid_block_scale_factor: float = 1,
254
+ act_fn: str = "silu",
255
+ norm_num_groups: Optional[int] = 32,
256
+ norm_eps: float = 1e-5,
257
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
258
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
259
+ encoder_hid_dim: Optional[int] = None,
260
+ encoder_hid_dim_type: Optional[str] = None,
261
+ attention_head_dim: Union[int, Tuple[int]] = 8,
262
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
263
+ dual_cross_attention: bool = False,
264
+ use_linear_projection: bool = False,
265
+ class_embed_type: Optional[str] = None,
266
+ addition_embed_type: Optional[str] = None,
267
+ addition_time_embed_dim: Optional[int] = None,
268
+ num_class_embeds: Optional[int] = None,
269
+ upcast_attention: bool = False,
270
+ resnet_time_scale_shift: str = "default",
271
+ resnet_skip_time_act: bool = False,
272
+ resnet_out_scale_factor: int = 1.0,
273
+ time_embedding_type: str = "positional",
274
+ time_embedding_dim: Optional[int] = None,
275
+ time_embedding_act_fn: Optional[str] = None,
276
+ timestep_post_act: Optional[str] = None,
277
+ time_cond_proj_dim: Optional[int] = None,
278
+ conv_in_kernel: int = 3,
279
+ conv_out_kernel: int = 3,
280
+ projection_class_embeddings_input_dim: Optional[int] = None,
281
+ class_embeddings_concat: bool = False,
282
+ mid_block_only_cross_attention: Optional[bool] = None,
283
+ cross_attention_norm: Optional[str] = None,
284
+ addition_embed_type_num_heads=64,
285
+ num_views: int = 1,
286
+ joint_attention: bool = False,
287
+ joint_attention_twice: bool = False,
288
+ multiview_attention: bool = True,
289
+ cross_domain_attention: bool = False,
290
+ camera_input_dim: int = 12,
291
+ camera_hidden_dim: int = 320,
292
+ camera_output_dim: int = 1280,
293
+
294
+ ):
295
+ super().__init__()
296
+
297
+ self.sample_size = sample_size
298
+
299
+ if num_attention_heads is not None:
300
+ raise ValueError(
301
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
302
+ )
303
+
304
+ # If `num_attention_heads` is not defined (which is the case for most models)
305
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
306
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
307
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
308
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
309
+ # which is why we correct for the naming here.
310
+ num_attention_heads = num_attention_heads or attention_head_dim
311
+
312
+ # Check inputs
313
+ if len(down_block_types) != len(up_block_types):
314
+ raise ValueError(
315
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
316
+ )
317
+
318
+ if len(block_out_channels) != len(down_block_types):
319
+ raise ValueError(
320
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
321
+ )
322
+
323
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
324
+ raise ValueError(
325
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
326
+ )
327
+
328
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
329
+ raise ValueError(
330
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
331
+ )
332
+
333
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
334
+ raise ValueError(
335
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
336
+ )
337
+
338
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
339
+ raise ValueError(
340
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
341
+ )
342
+
343
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
344
+ raise ValueError(
345
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
346
+ )
347
+
348
+ # input
349
+ conv_in_padding = (conv_in_kernel - 1) // 2
350
+ self.conv_in = nn.Conv2d(
351
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
352
+ )
353
+
354
+ # time
355
+ if time_embedding_type == "fourier":
356
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
357
+ if time_embed_dim % 2 != 0:
358
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
359
+ self.time_proj = GaussianFourierProjection(
360
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
361
+ )
362
+ timestep_input_dim = time_embed_dim
363
+ elif time_embedding_type == "positional":
364
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
365
+
366
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
367
+ timestep_input_dim = block_out_channels[0]
368
+ else:
369
+ raise ValueError(
370
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
371
+ )
372
+
373
+ self.time_embedding = TimestepEmbedding(
374
+ timestep_input_dim,
375
+ time_embed_dim,
376
+ act_fn=act_fn,
377
+ post_act_fn=timestep_post_act,
378
+ cond_proj_dim=time_cond_proj_dim,
379
+ )
380
+
381
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
382
+ encoder_hid_dim_type = "text_proj"
383
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
384
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
385
+
386
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
387
+ raise ValueError(
388
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
389
+ )
390
+
391
+ if encoder_hid_dim_type == "text_proj":
392
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
393
+ elif encoder_hid_dim_type == "text_image_proj":
394
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
395
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
396
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
397
+ self.encoder_hid_proj = TextImageProjection(
398
+ text_embed_dim=encoder_hid_dim,
399
+ image_embed_dim=cross_attention_dim,
400
+ cross_attention_dim=cross_attention_dim,
401
+ )
402
+ elif encoder_hid_dim_type == "image_proj":
403
+ # Kandinsky 2.2
404
+ self.encoder_hid_proj = ImageProjection(
405
+ image_embed_dim=encoder_hid_dim,
406
+ cross_attention_dim=cross_attention_dim,
407
+ )
408
+ elif encoder_hid_dim_type is not None:
409
+ raise ValueError(
410
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
411
+ )
412
+ else:
413
+ self.encoder_hid_proj = None
414
+
415
+ # class embedding
416
+ if class_embed_type is None and num_class_embeds is not None:
417
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
418
+ elif class_embed_type == "timestep":
419
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
420
+ elif class_embed_type == "identity":
421
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
422
+ elif class_embed_type == "projection":
423
+ if projection_class_embeddings_input_dim is None:
424
+ raise ValueError(
425
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
426
+ )
427
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
428
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
429
+ # 2. it projects from an arbitrary input dimension.
430
+ #
431
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
432
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
433
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
434
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
435
+ elif class_embed_type == "simple_projection":
436
+ if projection_class_embeddings_input_dim is None:
437
+ raise ValueError(
438
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
439
+ )
440
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
441
+ else:
442
+ self.class_embedding = None
443
+
444
+ if addition_embed_type == "text":
445
+ if encoder_hid_dim is not None:
446
+ text_time_embedding_from_dim = encoder_hid_dim
447
+ else:
448
+ text_time_embedding_from_dim = cross_attention_dim
449
+
450
+ self.add_embedding = TextTimeEmbedding(
451
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
452
+ )
453
+ elif addition_embed_type == "text_image":
454
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
455
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
456
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
457
+ self.add_embedding = TextImageTimeEmbedding(
458
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
459
+ )
460
+ elif addition_embed_type == "text_time":
461
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
462
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
463
+ elif addition_embed_type == "image":
464
+ # Kandinsky 2.2
465
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
466
+ elif addition_embed_type == "image_hint":
467
+ # Kandinsky 2.2 ControlNet
468
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
469
+ elif addition_embed_type is not None:
470
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
471
+
472
+ if time_embedding_act_fn is None:
473
+ self.time_embed_act = None
474
+ else:
475
+ self.time_embed_act = get_activation(time_embedding_act_fn)
476
+
477
+ self.camera_embedding = nn.Sequential(
478
+ nn.Linear(camera_input_dim, time_embed_dim),
479
+ nn.SiLU(),
480
+ nn.Linear(time_embed_dim, time_embed_dim),
481
+ )
482
+
483
+ self.down_blocks = nn.ModuleList([])
484
+ self.up_blocks = nn.ModuleList([])
485
+
486
+ if isinstance(only_cross_attention, bool):
487
+ if mid_block_only_cross_attention is None:
488
+ mid_block_only_cross_attention = only_cross_attention
489
+
490
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
491
+
492
+ if mid_block_only_cross_attention is None:
493
+ mid_block_only_cross_attention = False
494
+
495
+ if isinstance(num_attention_heads, int):
496
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
497
+
498
+ if isinstance(attention_head_dim, int):
499
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
500
+
501
+ if isinstance(cross_attention_dim, int):
502
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
503
+
504
+ if isinstance(layers_per_block, int):
505
+ layers_per_block = [layers_per_block] * len(down_block_types)
506
+
507
+ if isinstance(transformer_layers_per_block, int):
508
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
509
+
510
+ if class_embeddings_concat:
511
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
512
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
513
+ # regular time embeddings
514
+ blocks_time_embed_dim = time_embed_dim * 2
515
+ else:
516
+ blocks_time_embed_dim = time_embed_dim
517
+
518
+ # down
519
+ output_channel = block_out_channels[0]
520
+ for i, down_block_type in enumerate(down_block_types):
521
+ input_channel = output_channel
522
+ output_channel = block_out_channels[i]
523
+ is_final_block = i == len(block_out_channels) - 1
524
+
525
+ down_block = get_down_block(
526
+ down_block_type,
527
+ num_layers=layers_per_block[i],
528
+ transformer_layers_per_block=transformer_layers_per_block[i],
529
+ in_channels=input_channel,
530
+ out_channels=output_channel,
531
+ temb_channels=blocks_time_embed_dim,
532
+ add_downsample=not is_final_block,
533
+ resnet_eps=norm_eps,
534
+ resnet_act_fn=act_fn,
535
+ resnet_groups=norm_num_groups,
536
+ cross_attention_dim=cross_attention_dim[i],
537
+ num_attention_heads=num_attention_heads[i],
538
+ downsample_padding=downsample_padding,
539
+ dual_cross_attention=dual_cross_attention,
540
+ use_linear_projection=use_linear_projection,
541
+ only_cross_attention=only_cross_attention[i],
542
+ upcast_attention=upcast_attention,
543
+ resnet_time_scale_shift=resnet_time_scale_shift,
544
+ resnet_skip_time_act=resnet_skip_time_act,
545
+ resnet_out_scale_factor=resnet_out_scale_factor,
546
+ cross_attention_norm=cross_attention_norm,
547
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
548
+ num_views=num_views,
549
+ joint_attention=joint_attention,
550
+ joint_attention_twice=joint_attention_twice,
551
+ multiview_attention=multiview_attention,
552
+ cross_domain_attention=cross_domain_attention
553
+ )
554
+ self.down_blocks.append(down_block)
555
+
556
+ # mid
557
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
558
+ self.mid_block = UNetMidBlock2DCrossAttn(
559
+ transformer_layers_per_block=transformer_layers_per_block[-1],
560
+ in_channels=block_out_channels[-1],
561
+ temb_channels=blocks_time_embed_dim,
562
+ resnet_eps=norm_eps,
563
+ resnet_act_fn=act_fn,
564
+ output_scale_factor=mid_block_scale_factor,
565
+ resnet_time_scale_shift=resnet_time_scale_shift,
566
+ cross_attention_dim=cross_attention_dim[-1],
567
+ num_attention_heads=num_attention_heads[-1],
568
+ resnet_groups=norm_num_groups,
569
+ dual_cross_attention=dual_cross_attention,
570
+ use_linear_projection=use_linear_projection,
571
+ upcast_attention=upcast_attention,
572
+ )
573
+ # custom MV2D attention block
574
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
575
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
576
+ transformer_layers_per_block=transformer_layers_per_block[-1],
577
+ in_channels=block_out_channels[-1],
578
+ temb_channels=blocks_time_embed_dim,
579
+ resnet_eps=norm_eps,
580
+ resnet_act_fn=act_fn,
581
+ output_scale_factor=mid_block_scale_factor,
582
+ resnet_time_scale_shift=resnet_time_scale_shift,
583
+ cross_attention_dim=cross_attention_dim[-1],
584
+ num_attention_heads=num_attention_heads[-1],
585
+ resnet_groups=norm_num_groups,
586
+ dual_cross_attention=dual_cross_attention,
587
+ use_linear_projection=use_linear_projection,
588
+ upcast_attention=upcast_attention,
589
+ num_views=num_views,
590
+ joint_attention=joint_attention,
591
+ joint_attention_twice=joint_attention_twice,
592
+ multiview_attention=multiview_attention,
593
+ cross_domain_attention=cross_domain_attention
594
+ )
595
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
596
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
597
+ in_channels=block_out_channels[-1],
598
+ temb_channels=blocks_time_embed_dim,
599
+ resnet_eps=norm_eps,
600
+ resnet_act_fn=act_fn,
601
+ output_scale_factor=mid_block_scale_factor,
602
+ cross_attention_dim=cross_attention_dim[-1],
603
+ attention_head_dim=attention_head_dim[-1],
604
+ resnet_groups=norm_num_groups,
605
+ resnet_time_scale_shift=resnet_time_scale_shift,
606
+ skip_time_act=resnet_skip_time_act,
607
+ only_cross_attention=mid_block_only_cross_attention,
608
+ cross_attention_norm=cross_attention_norm,
609
+ )
610
+ elif mid_block_type is None:
611
+ self.mid_block = None
612
+ else:
613
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
614
+
615
+ # count how many layers upsample the images
616
+ self.num_upsamplers = 0
617
+
618
+ # up
619
+ reversed_block_out_channels = list(reversed(block_out_channels))
620
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
621
+ reversed_layers_per_block = list(reversed(layers_per_block))
622
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
623
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
624
+ only_cross_attention = list(reversed(only_cross_attention))
625
+
626
+ output_channel = reversed_block_out_channels[0]
627
+ for i, up_block_type in enumerate(up_block_types):
628
+ is_final_block = i == len(block_out_channels) - 1
629
+
630
+ prev_output_channel = output_channel
631
+ output_channel = reversed_block_out_channels[i]
632
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
633
+
634
+ # add upsample block for all BUT final layer
635
+ if not is_final_block:
636
+ add_upsample = True
637
+ self.num_upsamplers += 1
638
+ else:
639
+ add_upsample = False
640
+
641
+ up_block = get_up_block(
642
+ up_block_type,
643
+ num_layers=reversed_layers_per_block[i] + 1,
644
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
645
+ in_channels=input_channel,
646
+ out_channels=output_channel,
647
+ prev_output_channel=prev_output_channel,
648
+ temb_channels=blocks_time_embed_dim,
649
+ add_upsample=add_upsample,
650
+ resnet_eps=norm_eps,
651
+ resnet_act_fn=act_fn,
652
+ resnet_groups=norm_num_groups,
653
+ cross_attention_dim=reversed_cross_attention_dim[i],
654
+ num_attention_heads=reversed_num_attention_heads[i],
655
+ dual_cross_attention=dual_cross_attention,
656
+ use_linear_projection=use_linear_projection,
657
+ only_cross_attention=only_cross_attention[i],
658
+ upcast_attention=upcast_attention,
659
+ resnet_time_scale_shift=resnet_time_scale_shift,
660
+ resnet_skip_time_act=resnet_skip_time_act,
661
+ resnet_out_scale_factor=resnet_out_scale_factor,
662
+ cross_attention_norm=cross_attention_norm,
663
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
664
+ num_views=num_views,
665
+ joint_attention=joint_attention,
666
+ joint_attention_twice=joint_attention_twice,
667
+ multiview_attention=multiview_attention,
668
+ cross_domain_attention=cross_domain_attention
669
+ )
670
+ self.up_blocks.append(up_block)
671
+ prev_output_channel = output_channel
672
+
673
+ # out
674
+ # if norm_num_groups is not None:
675
+ # self.conv_norm_out = nn.GroupNorm(
676
+ # num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
677
+ # )
678
+
679
+ # self.conv_act = get_activation(act_fn)
680
+
681
+ # else:
682
+ # self.conv_norm_out = None
683
+ # self.conv_act = None
684
+
685
+ # conv_out_padding = (conv_out_kernel - 1) // 2
686
+ # self.conv_out = nn.Conv2d(
687
+ # block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
688
+ # )
689
+
690
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
691
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
692
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
693
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
694
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
695
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
696
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
697
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
698
+ self.up_blocks[3].attentions[2].proj_out = Identity()
699
+
700
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect the attention processors currently attached to the model.

    Returns:
        `dict` of attention processors: one entry per submodule that exposes a `set_processor` method, keyed by
        the dotted path of that submodule followed by `.processor`.
    """
    collected = {}

    # Explicit-stack, pre-order walk. Children are pushed in reverse so they are popped
    # (and therefore inserted into `collected`) in declaration order.
    pending = list(reversed(list(self.named_children())))
    while pending:
        path, layer = pending.pop()

        # Any module exposing `set_processor` is treated as an attention layer.
        if hasattr(layer, "set_processor"):
            collected[f"{path}.processor"] = layer.processor

        nested = [(f"{path}.{child_name}", child) for child_name, child in layer.named_children()]
        pending.extend(reversed(nested))

    return collected
723
+
724
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

            A dict argument is only read; it is not modified by this call.

    Raises:
        ValueError: If a dict is passed whose length differs from the number of attention layers.
        KeyError: If a dict is passed that has no `"<path>.processor"` entry for some attention layer.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        # Any module exposing `set_processor` is treated as an attention layer.
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Look the entry up rather than popping it, so the caller's dict
                # is not emptied as a side effect of applying the processors.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
757
+
758
def set_default_attn_processor(self):
    """
    Install a newly constructed `AttnProcessor` through `set_attn_processor`, replacing whatever
    attention processors were set before.
    """
    default_processor = AttnProcessor()
    self.set_attn_processor(default_processor)
763
+
764
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`. A list supplies one value per sliceable layer, in traversal
            order; the list itself is not modified.

    Raises:
        ValueError: If a list is given whose length differs from the number of sliceable layers, or if any
            requested size exceeds the corresponding layer's `sliceable_head_dim`.
    """
    sliceable_head_dims = []

    def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            sliceable_head_dims.append(module.sliceable_head_dim)

        for child in module.children():
            fn_recursive_retrieve_sliceable_dims(child)

    # retrieve number of attention layers
    for module in self.children():
        fn_recursive_retrieve_sliceable_dims(module)

    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = num_sliceable_layers * [1]

    # A scalar is broadcast to every sliceable layer; a list is used as given.
    slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Recursively walk through all the children.
    # Any children which exposes the set_attention_slice method
    # gets the message
    def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(slice_size.pop())

        for child in module.children():
            fn_recursive_set_attention_slice(child, slice_size)

    # Work on a reversed copy so that `pop()` hands out sizes in traversal order
    # without consuming the caller's list.
    reversed_slice_size = list(reversed(slice_size))
    for module in self.children():
        fn_recursive_set_attention_slice(module, reversed_slice_size)
828
+
829
def _set_gradient_checkpointing(self, module, value=False):
    """
    Set `module.gradient_checkpointing` to `value` when `module` is an instance of one of the
    down/up block classes listed below; any other module is left untouched.
    """
    checkpointable_blocks = (
        CrossAttnDownBlock2D,
        CrossAttnDownBlockMV2D,
        DownBlock2D,
        CrossAttnUpBlock2D,
        CrossAttnUpBlockMV2D,
        UpBlock2D,
    )
    if not isinstance(module, checkpointable_blocks):
        return
    module.gradient_checkpointing = value
832
+
833
+ def forward(
834
+ self,
835
+ sample: torch.FloatTensor,
836
+ timestep: Union[torch.Tensor, float, int],
837
+ encoder_hidden_states: torch.Tensor,
838
+ camera_matrixs: Optional[torch.Tensor] = None,
839
+ class_labels: Optional[torch.Tensor] = None,
840
+ timestep_cond: Optional[torch.Tensor] = None,
841
+ attention_mask: Optional[torch.Tensor] = None,
842
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
843
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
844
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
845
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
846
+ encoder_attention_mask: Optional[torch.Tensor] = None,
847
+ return_dict: bool = True,
848
+ ) -> Union[UNetMV2DRefOutput, Tuple]:
849
+ r"""
850
+ The [`UNet2DConditionModel`] forward method.
851
+
852
+ Args:
853
+ sample (`torch.FloatTensor`):
854
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
855
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
856
+ encoder_hidden_states (`torch.FloatTensor`):
857
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
858
+ encoder_attention_mask (`torch.Tensor`):
859
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
860
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
861
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
862
+ return_dict (`bool`, *optional*, defaults to `True`):
863
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
864
+ tuple.
865
+ cross_attention_kwargs (`dict`, *optional*):
866
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
867
+ added_cond_kwargs: (`dict`, *optional*):
868
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
869
+ are passed along to the UNet blocks.
870
+
871
+ Returns:
872
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
873
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
874
+ a `tuple` is returned where the first element is the sample tensor.
875
+ """
876
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
877
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
878
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
879
+ # on the fly if necessary.
880
+ default_overall_up_factor = 2**self.num_upsamplers
881
+
882
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
883
+ forward_upsample_size = False
884
+ upsample_size = None
885
+
886
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
887
+ logger.info("Forward upsample size to force interpolation output size.")
888
+ forward_upsample_size = True
889
+
890
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
891
+ # expects mask of shape:
892
+ # [batch, key_tokens]
893
+ # adds singleton query_tokens dimension:
894
+ # [batch, 1, key_tokens]
895
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
896
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
897
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
898
+ if attention_mask is not None:
899
+ # assume that mask is expressed as:
900
+ # (1 = keep, 0 = discard)
901
+ # convert mask into a bias that can be added to attention scores:
902
+ # (keep = +0, discard = -10000.0)
903
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
904
+ attention_mask = attention_mask.unsqueeze(1)
905
+
906
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
907
+ if encoder_attention_mask is not None:
908
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
909
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
910
+
911
+ # 0. center input if necessary
912
+ if self.config.center_input_sample:
913
+ sample = 2 * sample - 1.0
914
+
915
+ # 1. time
916
+ timesteps = timestep
917
+ if not torch.is_tensor(timesteps):
918
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
919
+ # This would be a good case for the `match` statement (Python 3.10+)
920
+ is_mps = sample.device.type == "mps"
921
+ if isinstance(timestep, float):
922
+ dtype = torch.float32 if is_mps else torch.float64
923
+ else:
924
+ dtype = torch.int32 if is_mps else torch.int64
925
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
926
+ elif len(timesteps.shape) == 0:
927
+ timesteps = timesteps[None].to(sample.device)
928
+
929
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
930
+ timesteps = timesteps.expand(sample.shape[0])
931
+
932
+ t_emb = self.time_proj(timesteps)
933
+
934
+ # `Timesteps` does not contain any weights and will always return f32 tensors
935
+ # but time_embedding might actually be running in fp16. so we need to cast here.
936
+ # there might be better ways to encapsulate this.
937
+ t_emb = t_emb.to(dtype=sample.dtype)
938
+ emb = self.time_embedding(t_emb, timestep_cond)
939
+
940
+ # import pdb; pdb.set_trace()
941
+ if camera_matrixs is not None:
942
+ emb = torch.unsqueeze(emb, 1)
943
+ # came emb
944
+ cam_emb = self.camera_embedding(camera_matrixs)
945
+ # cam_emb = self.camera_embedding_2(cam_emb)
946
+ emb = emb.repeat(1,cam_emb.shape[1],1) #torch.Size([32, 4, 1280])
947
+ emb = emb + cam_emb
948
+ emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])
949
+
950
+ aug_emb = None
951
+
952
+ if self.class_embedding is not None and class_labels is not None:
953
+ if class_labels is None:
954
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
955
+
956
+ if self.config.class_embed_type == "timestep":
957
+ class_labels = self.time_proj(class_labels)
958
+
959
+ # `Timesteps` does not contain any weights and will always return f32 tensors
960
+ # there might be better ways to encapsulate this.
961
+ class_labels = class_labels.to(dtype=sample.dtype)
962
+
963
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
964
+
965
+ if self.config.class_embeddings_concat:
966
+ emb = torch.cat([emb, class_emb], dim=-1)
967
+ else:
968
+ emb = emb + class_emb
969
+
970
+ if self.config.addition_embed_type == "text":
971
+ aug_emb = self.add_embedding(encoder_hidden_states)
972
+ elif self.config.addition_embed_type == "text_image":
973
+ # Kandinsky 2.1 - style
974
+ if "image_embeds" not in added_cond_kwargs:
975
+ raise ValueError(
976
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
977
+ )
978
+
979
+ image_embs = added_cond_kwargs.get("image_embeds")
980
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
981
+ aug_emb = self.add_embedding(text_embs, image_embs)
982
+ elif self.config.addition_embed_type == "text_time":
983
+ # SDXL - style
984
+ if "text_embeds" not in added_cond_kwargs:
985
+ raise ValueError(
986
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
987
+ )
988
+ text_embeds = added_cond_kwargs.get("text_embeds")
989
+ if "time_ids" not in added_cond_kwargs:
990
+ raise ValueError(
991
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
992
+ )
993
+ time_ids = added_cond_kwargs.get("time_ids")
994
+ time_embeds = self.add_time_proj(time_ids.flatten())
995
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
996
+
997
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
998
+ add_embeds = add_embeds.to(emb.dtype)
999
+ aug_emb = self.add_embedding(add_embeds)
1000
+ elif self.config.addition_embed_type == "image":
1001
+ # Kandinsky 2.2 - style
1002
+ if "image_embeds" not in added_cond_kwargs:
1003
+ raise ValueError(
1004
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1005
+ )
1006
+ image_embs = added_cond_kwargs.get("image_embeds")
1007
+ aug_emb = self.add_embedding(image_embs)
1008
+ elif self.config.addition_embed_type == "image_hint":
1009
+ # Kandinsky 2.2 - style
1010
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1011
+ raise ValueError(
1012
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1013
+ )
1014
+ image_embs = added_cond_kwargs.get("image_embeds")
1015
+ hint = added_cond_kwargs.get("hint")
1016
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1017
+ sample = torch.cat([sample, hint], dim=1)
1018
+
1019
+ emb = emb + aug_emb if aug_emb is not None else emb
1020
+
1021
+ if self.time_embed_act is not None:
1022
+ emb = self.time_embed_act(emb)
1023
+
1024
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1025
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1026
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1027
+ # Kadinsky 2.1 - style
1028
+ if "image_embeds" not in added_cond_kwargs:
1029
+ raise ValueError(
1030
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1031
+ )
1032
+
1033
+ image_embeds = added_cond_kwargs.get("image_embeds")
1034
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1035
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1036
+ # Kandinsky 2.2 - style
1037
+ if "image_embeds" not in added_cond_kwargs:
1038
+ raise ValueError(
1039
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1040
+ )
1041
+ image_embeds = added_cond_kwargs.get("image_embeds")
1042
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1043
+ # 2. pre-process
1044
+ sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
1045
+ sample = self.conv_in(sample)
1046
+ # 3. down
1047
+
1048
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1049
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1050
+
1051
+ down_block_res_samples = (sample,)
1052
+ for downsample_block in self.down_blocks:
1053
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1054
+ # For t2i-adapter CrossAttnDownBlock2D
1055
+ additional_residuals = {}
1056
+ if is_adapter and len(down_block_additional_residuals) > 0:
1057
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1058
+
1059
+ sample, res_samples = downsample_block(
1060
+ hidden_states=sample,
1061
+ temb=emb,
1062
+ encoder_hidden_states=encoder_hidden_states,
1063
+ attention_mask=attention_mask,
1064
+ cross_attention_kwargs=cross_attention_kwargs,
1065
+ encoder_attention_mask=encoder_attention_mask,
1066
+ **additional_residuals,
1067
+ )
1068
+ else:
1069
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1070
+
1071
+ if is_adapter and len(down_block_additional_residuals) > 0:
1072
+ sample += down_block_additional_residuals.pop(0)
1073
+
1074
+ down_block_res_samples += res_samples
1075
+
1076
+ if is_controlnet:
1077
+ new_down_block_res_samples = ()
1078
+
1079
+ for down_block_res_sample, down_block_additional_residual in zip(
1080
+ down_block_res_samples, down_block_additional_residuals
1081
+ ):
1082
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1083
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1084
+
1085
+ down_block_res_samples = new_down_block_res_samples
1086
+ # print("after down: ", sample.mean(), emb.mean())
1087
+
1088
+ # 4. mid
1089
+ if self.mid_block is not None:
1090
+ sample = self.mid_block(
1091
+ sample,
1092
+ emb,
1093
+ encoder_hidden_states=encoder_hidden_states,
1094
+ attention_mask=attention_mask,
1095
+ cross_attention_kwargs=cross_attention_kwargs,
1096
+ encoder_attention_mask=encoder_attention_mask,
1097
+ )
1098
+
1099
+ if is_controlnet:
1100
+ sample = sample + mid_block_additional_residual
1101
+
1102
+ # print("after mid: ", sample.mean())
1103
+ # 5. up
1104
+ for i, upsample_block in enumerate(self.up_blocks):
1105
+ is_final_block = i == len(self.up_blocks) - 1
1106
+
1107
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1108
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1109
+
1110
+ # if we have not reached the final block and need to forward the
1111
+ # upsample size, we do it here
1112
+ if not is_final_block and forward_upsample_size:
1113
+ upsample_size = down_block_res_samples[-1].shape[2:]
1114
+
1115
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1116
+ sample = upsample_block(
1117
+ hidden_states=sample,
1118
+ temb=emb,
1119
+ res_hidden_states_tuple=res_samples,
1120
+ encoder_hidden_states=encoder_hidden_states,
1121
+ cross_attention_kwargs=cross_attention_kwargs,
1122
+ upsample_size=upsample_size,
1123
+ attention_mask=attention_mask,
1124
+ encoder_attention_mask=encoder_attention_mask,
1125
+ )
1126
+ else:
1127
+ sample = upsample_block(
1128
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1129
+ )
1130
+
1131
+ # 6. post-process
1132
+ # if self.conv_norm_out:
1133
+ # sample = self.conv_norm_out(sample)
1134
+ # sample = self.conv_act(sample)
1135
+ # sample = self.conv_out(sample)
1136
+
1137
+ if not return_dict:
1138
+ return (sample,)
1139
+
1140
+ return UNetMV2DRefOutput(sample=sample)
1141
+
1142
@classmethod
def from_pretrained_2d(
    cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    camera_embedding_type: str, num_views: int, sample_size: int,
    zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
    projection_class_embeddings_input_dim: int = 6, joint_attention: bool = False,
    joint_attention_twice: bool = False, multiview_attention: bool = True,
    cross_domain_attention: bool = False,
    in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
    **kwargs
):
    r"""
    Build this multi-view UNet from a pretrained 2D UNet checkpoint.

    The checkpoint's config is loaded and then overridden (channels, sample size, number of views,
    attention switches, MV2D block types and a projection class embedding) before the model is
    instantiated. Weights are loaded with size mismatches tolerated; mismatched `conv_in.weight` /
    `conv_out.weight` tensors are partially re-initialised from the checkpoint (see below). The returned
    model is in evaluation mode.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Model id or directory; used both as the config path and to resolve the weight file.
        camera_embedding_type (`str`):
            Only `'e_de_da_sincos'` is supported; any other value raises `NotImplementedError`.
        num_views (`int`), sample_size (`int`):
            Written to `config['num_views']` and `config['sample_size']`.
        zero_init_conv_in (`bool`, defaults to `True`):
            When `conv_in.weight` has a mismatched shape, the first 4 input channels are copied from the
            checkpoint; if this flag is set, the remaining input channels are zeroed.
        zero_init_camera_projection (`bool`, defaults to `False`):
            Zero every parameter of `model.class_embedding`.
        projection_class_embeddings_input_dim (`int`, defaults to 6):
            Written to `config['projection_class_embeddings_input_dim']`.
        joint_attention, joint_attention_twice, multiview_attention, cross_domain_attention (`bool`):
            Written to the config entries of the same name.
        in_channels (`int`, defaults to 8), out_channels (`int`, defaults to 4):
            Written to the config. When `conv_out.weight` has a mismatched shape, the first 4 output
            filters are copied from the checkpoint, and for `out_channels == 8` the last 4 as well.
        local_crossattn (`bool`, defaults to `False`):
            Wrap every attention processor in `ReferenceOnlyAttnProc`, enabled only for processors whose
            name ends with `"attn1.processor"`.
        kwargs:
            `cache_dir`, `force_download`, `resume_download`, `proxies`, `local_files_only`,
            `use_auth_token`, `revision`, `subfolder` are used for both config and weight resolution.
            `device_map`, `max_memory`, `offload_folder`, `offload_state_dict` are forwarded to
            `load_config` only. `variant` selects the weight file variant. `use_safetensors` left at
            `None` means the pickle weights file is used. `torch_dtype` casts the loaded model.
            `from_flax=True` raises `NotImplementedError`. `output_loading_info=True` additionally
            returns the loading report. `ignore_mismatched_sizes` is accepted but has no effect:
            mismatched sizes are always ignored. Remaining entries are forwarded to `load_config`.

    Returns:
        The model, or `(model, loading_info)` when `output_loading_info` is set, where `loading_info`
        has the keys `missing_keys`, `unexpected_keys`, `mismatched_keys` and `error_msgs`.

    Raises:
        NotImplementedError: `device_map` given without accelerate or with torch < 1.9.0, `from_flax`
            set, or an unsupported `camera_embedding_type`.
        ValueError: `torch_dtype` is given but is not a `torch.dtype`.
    """
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    # Popped so it is not forwarded to `load_config`; loading below always ignores mismatched sizes.
    kwargs.pop("ignore_mismatched_sizes", False)
    force_download = kwargs.pop("force_download", False)
    from_flax = kwargs.pop("from_flax", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", None)
    device_map = kwargs.pop("device_map", None)
    max_memory = kwargs.pop("max_memory", None)
    offload_folder = kwargs.pop("offload_folder", None)
    offload_state_dict = kwargs.pop("offload_state_dict", False)
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # An unset `use_safetensors` selects the pickle weights file (no safetensors probing is done here).
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = False
        allow_pickle = True

    if device_map is not None and not is_accelerate_available():
        raise NotImplementedError(
            "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
            " `device_map=None`. You can install accelerate with `pip install accelerate`."
        )

    # Check if we can handle device_map and dispatching the weights
    if device_map is not None and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `device_map=None`."
        )

    # The config is always read from the checkpoint location itself.
    config_path = pretrained_model_name_or_path

    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "pytorch",
    }

    # load config
    config, unused_kwargs, commit_hash = cls.load_config(
        config_path,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        return_commit_hash=True,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        revision=revision,
        subfolder=subfolder,
        device_map=device_map,
        max_memory=max_memory,
        offload_folder=offload_folder,
        offload_state_dict=offload_state_dict,
        user_agent=user_agent,
        **kwargs,
    )

    # Override the pretrained 2D config with the multi-view settings.
    config["_class_name"] = cls.__name__
    config['in_channels'] = in_channels
    config['out_channels'] = out_channels
    config['sample_size'] = sample_size  # training resolution
    config['num_views'] = num_views
    config['joint_attention'] = joint_attention
    config['joint_attention_twice'] = joint_attention_twice
    config['multiview_attention'] = multiview_attention
    config['cross_domain_attention'] = cross_domain_attention
    config["down_block_types"] = [
        "CrossAttnDownBlockMV2D",
        "CrossAttnDownBlockMV2D",
        "CrossAttnDownBlockMV2D",
        "DownBlock2D",
    ]
    config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
    config["up_block_types"] = [
        "UpBlock2D",
        "CrossAttnUpBlockMV2D",
        "CrossAttnUpBlockMV2D",
        "CrossAttnUpBlockMV2D",
    ]
    config['class_embed_type'] = 'projection'
    if camera_embedding_type == 'e_de_da_sincos':
        config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim  # default 6
    else:
        raise NotImplementedError

    # Resolve the weight file.
    model_file = None
    if from_flax:
        raise NotImplementedError

    if use_safetensors:
        try:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )
        except IOError:
            # Only fall through to the pickle weights when that fallback is permitted.
            if not allow_pickle:
                raise
    if model_file is None:
        model_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=_add_variant(WEIGHTS_NAME, variant),
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )

    model = cls.from_config(config, **unused_kwargs)
    if local_crossattn:
        unet_lora_attn_procs = dict()
        for name in model.attn_processors:
            is_self_attn = name.endswith("attn1.processor")
            if not is_self_attn:
                default_attn_proc = AttnProcessor()
            elif is_xformers_available():
                default_attn_proc = XFormersMVAttnProcessor()
            else:
                default_attn_proc = MVAttnProcessor()
            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=is_self_attn, name=name
            )
        model.set_attn_processor(unet_lora_attn_procs)
    state_dict = load_state_dict(model_file, variant=variant)
    model._convert_deprecated_attention_blocks(state_dict)

    # Keep references to the pretrained conv weights: mismatched entries are deleted from
    # `state_dict` by `_load_pretrained_model_2d`.
    conv_in_weight = state_dict['conv_in.weight']
    conv_out_weight = state_dict['conv_out.weight']
    model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
        model,
        state_dict,
        model_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=True,
    )
    mismatched_names = {key for key, _, _ in mismatched_keys}

    if 'conv_in.weight' in mismatched_names:
        # Conv weights are (out_channels, in_channels, kH, kW): reuse the pretrained filters for the
        # first 4 input channels.
        model.conv_in.weight.data[:, :4] = conv_in_weight

        # whether to place all zero to new layers?
        if zero_init_conv_in:
            model.conv_in.weight.data[:, 4:] = 0.

    if 'conv_out.weight' in mismatched_names:
        # The output channels live on dim 0 of the conv weight, so the pretrained filters must be
        # copied along dim 0 (slicing dim 1 would address the input channels and fail to broadcast).
        model.conv_out.weight.data[:4] = conv_out_weight
        if out_channels == 8:  # copy for the last 4 channels
            model.conv_out.weight.data[4:] = conv_out_weight
        # NOTE(review): a mismatched `conv_out.bias` is not re-initialised from the checkpoint here.

    if zero_init_camera_projection:
        for p in model.class_embedding.parameters():
            torch.nn.init.zeros_(p)

    loading_info = {
        "missing_keys": missing_keys,
        "unexpected_keys": unexpected_keys,
        "mismatched_keys": mismatched_keys,
        "error_msgs": error_msgs,
    }

    if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
        raise ValueError(
            f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
        )
    elif torch_dtype is not None:
        model = model.to(torch_dtype)

    model.register_to_config(_name_or_path=pretrained_model_name_or_path)

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    if output_loading_info:
        return model, loading_info

    return model
1466
+
1467
@classmethod
def _load_pretrained_model_2d(
    cls,
    model,
    state_dict,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
):
    """Load `state_dict` into `model` and report which keys did not line up.

    When `ignore_mismatched_sizes` is set, checkpoint entries whose shape differs from the model's
    parameter of the same name are removed from `state_dict` (in place) before loading and reported
    as `(key, checkpoint_shape, model_shape)` tuples.

    Returns:
        `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`.

    Raises:
        RuntimeError: if loading the state dict produced any error messages.
    """
    target_state = model.state_dict()
    checkpoint_names = list(state_dict.keys())
    model_names = list(target_state.keys())

    missing_keys = list(set(model_names) - set(checkpoint_names))
    unexpected_keys = list(set(checkpoint_names) - set(model_names))

    model_cls = model.__class__.__name__

    if state_dict is not None:
        # Collect (and drop) checkpoint tensors whose shape disagrees with the model.
        mismatched_keys = []
        if ignore_mismatched_sizes:
            for name in checkpoint_names:
                if name not in target_state:
                    continue
                ckpt_shape = state_dict[name].shape
                model_shape = target_state[name].shape
                if ckpt_shape != model_shape:
                    mismatched_keys.append((name, ckpt_shape, model_shape))
                    del state_dict[name]
        error_msgs = _load_state_dict_into_model(model, state_dict)

    if error_msgs:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model_cls}:\n\t{error_msg}")

    if unexpected_keys:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model_cls}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model_cls} from the checkpoint of a model trained on another task"
            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model_cls} from the checkpoint of a model that you expect to be exactly"
            " identical (initializing a BertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model_cls}.\n")

    if missing_keys:
        logger.warning(
            f"Some weights of {model_cls} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif not mismatched_keys:
        logger.info(
            f"All the weights of {model_cls} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
            f" checkpoint was trained on, you can already use {model_cls} for predictions"
            " without further training."
        )

    if mismatched_keys:
        mismatched_warning = "\n".join(
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        )
        logger.warning(
            f"Some weights of {model_cls} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
            " able to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1570
+
2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2
+
3
+ import tqdm
4
+
5
+ import inspect
6
+ from typing import Callable, List, Optional, Union
7
+ from dataclasses import dataclass
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from diffusers.utils import is_accelerate_available
13
+ from packaging import version
14
+ from transformers import CLIPTextModel, CLIPTokenizer
15
+ import torchvision.transforms.functional as TF
16
+
17
+ from diffusers.configuration_utils import FrozenDict
18
+ from diffusers.models import AutoencoderKL
19
+ from diffusers import DiffusionPipeline
20
+ from diffusers.schedulers import (
21
+ DDIMScheduler,
22
+ DPMSolverMultistepScheduler,
23
+ EulerAncestralDiscreteScheduler,
24
+ EulerDiscreteScheduler,
25
+ LMSDiscreteScheduler,
26
+ PNDMScheduler,
27
+ )
28
+ from diffusers.utils import deprecate, logging, BaseOutput
29
+
30
+ from einops import rearrange
31
+
32
+ from ..models.unet import UNet3DConditionModel
33
+ from torchvision.transforms import InterpolationMode
34
+
35
+ import ipdb
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
@dataclass
class TuneAVideoPipelineOutput(BaseOutput):
    """Output container with a single `videos` field."""

    # Either a torch tensor or a numpy array (layout is not fixed by this class).
    videos: Union[torch.Tensor, np.ndarray]
43
+
44
+
45
+ class TuneAVideoPipeline(DiffusionPipeline):
46
+ _optional_components = []
47
+
48
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet3DConditionModel,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
    ],
    ref_unet=None,
    feature_extractor=None,
    image_encoder=None,
):
    """Store the pipeline components and patch outdated scheduler / UNet configs.

    `ref_unet`, `feature_extractor` and `image_encoder` are kept as plain attributes; only `vae`,
    `text_encoder`, `tokenizer`, `unet` and `scheduler` go through `register_modules`.
    """
    super().__init__()
    self.ref_unet = ref_unet
    self.feature_extractor = feature_extractor
    self.image_encoder = image_encoder

    def _force_config_value(component, key, value):
        # Copy the component's config, override one entry and freeze it again.
        patched = dict(component.config)
        patched[key] = value
        component._internal_dict = FrozenDict(patched)

    sched_cfg = scheduler.config
    if hasattr(sched_cfg, "steps_offset") and sched_cfg.steps_offset != 1:
        steps_offset_msg = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", steps_offset_msg, standard_warn=False)
        _force_config_value(scheduler, "steps_offset", 1)

    # Re-read the config: it may just have been replaced above.
    sched_cfg = scheduler.config
    if hasattr(sched_cfg, "clip_sample") and sched_cfg.clip_sample is True:
        clip_sample_msg = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", clip_sample_msg, standard_warn=False)
        _force_config_value(scheduler, "clip_sample", False)

    # Both flags are evaluated up front, exactly as two independent checks.
    unet_is_pre_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    unet_sample_size_below_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if unet_is_pre_0_9_0 and unet_sample_size_below_64:
        sample_size_msg = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", sample_size_msg, standard_warn=False)
        _force_config_value(unet, "sample_size", 64)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    # One factor of 2 per VAE block after the first.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
128
+
129
def enable_vae_slicing(self):
    """Delegate to `self.vae.enable_slicing()`."""
    self.vae.enable_slicing()
131
+
132
def disable_vae_slicing(self):
    """Delegate to `self.vae.disable_slicing()`."""
    self.vae.disable_slicing()
134
+
135
def enable_sequential_cpu_offload(self, gpu_id=0):
    """Apply accelerate's `cpu_offload` to the UNet, text encoder and VAE.

    Args:
        gpu_id: index of the CUDA device passed to `cpu_offload` as `cuda:{gpu_id}`.

    Raises:
        ImportError: if accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    components = (self.unet, self.text_encoder, self.vae)
    for component in components:
        if component is None:
            continue
        cpu_offload(component, target_device)
146
+
147
+
148
@property
def _execution_device(self):
    """Return the device the pipeline's modules execute on.

    `self.device` is returned unless it is the `meta` device and the UNet carries an `_hf_hook`; in that
    case the first non-None `execution_device` found on a UNet submodule's hook is used, falling back
    to `self.device` when none is found.
    """
    hooked_on_meta = self.device == torch.device("meta") and hasattr(self.unet, "_hf_hook")
    if not hooked_on_meta:
        return self.device

    for submodule in self.unet.modules():
        hook = getattr(submodule, "_hf_hook", None)
        hook_device = getattr(hook, "execution_device", None)
        if hook_device is not None:
            return torch.device(hook_device)

    return self.device
160
+
161
def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance, img_proj=None):
    """Encode the conditioning image into image-encoder embeddings and VAE latents.

    Args:
        image_pil: image batch. NOTE(review): despite the name it is used as a tensor
            (`.float()`, `* 2.0 - 1.0`), presumably with values in [0, 1] — confirm against the caller.
        device: device the feature extractor's mean/std tensors are moved to.
        num_images_per_prompt: number of times the embeddings (when `img_proj` is None) and the latents
            are tiled along the batch dimension.
        do_classifier_free_guidance: when set, an all-zero unconditional entry is prepended — zero
            embeddings when `img_proj` is None, zero input images otherwise.
        img_proj: optional callable applied to the image encoder's penultimate hidden state. When
            given, it replaces the pooled `image_embeds` path.

    Returns:
        `(image_embeddings, image_latents)`.
    """
    dtype = next(self.image_encoder.parameters()).dtype

    # Resize to the feature extractor's crop size, then normalise with its mean / std.
    clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:, None, None].to(device, dtype=torch.float32)
    clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:, None, None].to(device, dtype=torch.float32)
    crop_size = self.feature_extractor.crop_size
    imgs_in_proc = TF.resize(
        image_pil, (crop_size['height'], crop_size['width']), interpolation=InterpolationMode.BICUBIC
    )
    # do the normalization in float32 to preserve precision
    imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(dtype)

    if img_proj is None:
        # Pooled embedding with a length-1 sequence axis inserted at dim 1.
        image_embeddings = self.image_encoder(imgs_in_proc).image_embeds.unsqueeze(1)
        # duplicate image embeddings for each generation per prompt, using mps friendly method
        # Note: repeat differently from official pipelines
        # B1B2B3B4 -> B1B2B3B4B1B2B3B4
        image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
        if do_classifier_free_guidance:
            # Concatenate unconditional (zero) and conditional embeddings into one batch so a single
            # forward pass covers both.
            negative_prompt_embeds = torch.zeros_like(image_embeddings)
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
    else:
        if do_classifier_free_guidance:
            # Here the unconditional branch is an all-zero (already normalised) input image.
            negative_image_proc = torch.zeros_like(imgs_in_proc)
            imgs_in_proc = torch.cat([negative_image_proc, imgs_in_proc])

        # Use the pipeline's own encoder; the bare name `image_encoder` is not defined in this scope.
        image_embeds = self.image_encoder(imgs_in_proc, output_hidden_states=True).hidden_states[-2]
        image_embeddings = img_proj(image_embeds)
        # NOTE(review): unlike the branch above, these embeddings are not tiled by
        # `num_images_per_prompt` — confirm this is intended.

    # VAE input is rescaled from [0, 1] to [-1, 1]; the distribution mode is scaled by the VAE's
    # configured scaling factor.
    image_latents = self.vae.encode(image_pil * 2.0 - 1.0).latent_dist.mode() * self.vae.config.scaling_factor

    # Note: repeat differently from official pipelines
    # B1B2B3B4 -> B1B2B3B4B1B2B3B4
    image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)

    return image_embeddings, image_latents
216
+
217
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
218
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
219
+
220
+ text_inputs = self.tokenizer(
221
+ prompt,
222
+ padding="max_length",
223
+ max_length=self.tokenizer.model_max_length,
224
+ truncation=True,
225
+ return_tensors="pt",
226
+ )
227
+ text_input_ids = text_inputs.input_ids
228
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
229
+
230
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
231
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
232
+ logger.warning(
233
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
234
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
235
+ )
236
+
237
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
238
+ attention_mask = text_inputs.attention_mask.to(device)
239
+ else:
240
+ attention_mask = None
241
+
242
+ text_embeddings = self.text_encoder(
243
+ text_input_ids.to(device),
244
+ attention_mask=attention_mask,
245
+ )
246
+ text_embeddings = text_embeddings[0]
247
+
248
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
249
+ bs_embed, seq_len, _ = text_embeddings.shape
250
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
251
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
252
+
253
+ # get unconditional embeddings for classifier free guidance
254
+ if do_classifier_free_guidance:
255
+ uncond_tokens: List[str]
256
+ if negative_prompt is None:
257
+ uncond_tokens = [""] * batch_size
258
+ elif type(prompt) is not type(negative_prompt):
259
+ raise TypeError(
260
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
261
+ f" {type(prompt)}."
262
+ )
263
+ elif isinstance(negative_prompt, str):
264
+ uncond_tokens = [negative_prompt]
265
+ elif batch_size != len(negative_prompt):
266
+ raise ValueError(
267
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
268
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
269
+ " the batch size of `prompt`."
270
+ )
271
+ else:
272
+ uncond_tokens = negative_prompt
273
+
274
+ max_length = text_input_ids.shape[-1]
275
+ uncond_input = self.tokenizer(
276
+ uncond_tokens,
277
+ padding="max_length",
278
+ max_length=max_length,
279
+ truncation=True,
280
+ return_tensors="pt",
281
+ )
282
+
283
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
284
+ attention_mask = uncond_input.attention_mask.to(device)
285
+ else:
286
+ attention_mask = None
287
+
288
+ uncond_embeddings = self.text_encoder(
289
+ uncond_input.input_ids.to(device),
290
+ attention_mask=attention_mask,
291
+ )
292
+ uncond_embeddings = uncond_embeddings[0]
293
+
294
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
295
+ seq_len = uncond_embeddings.shape[1]
296
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
297
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
298
+
299
+ # For classifier free guidance, we need to do two forward passes.
300
+ # Here we concatenate the unconditional and text embeddings into a single batch
301
+ # to avoid doing two forward passes
302
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
303
+
304
+ return text_embeddings
305
+
306
def decode_latents(self, latents):
    """Decode 5-D video latents (b, c, f, h, w) into a float32 numpy array.

    Frames are folded into the batch dimension for the VAE, unfolded again,
    and mapped from [-1, 1] to [0, 1] with clamping.
    """
    video_length = latents.shape[2]
    # Use the VAE's own scaling factor (as the encode path does) rather than a
    # hard-coded constant; fall back to the previous value if it is absent.
    scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
    latents = 1 / scaling_factor * latents
    latents = rearrange(latents, "b c f h w -> (b f) c h w")
    video = self.vae.decode(latents).sample
    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
    video = (video / 2 + 0.5).clamp(0, 1)
    # Always cast to float32: negligible overhead and compatible with bfloat16.
    video = video.cpu().float().numpy()
    return video
316
+
317
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by ``self.scheduler.step``.

    Schedulers differ in signature, so ``eta`` and ``generator`` are only
    forwarded when the step function declares a parameter of that name.
    ``eta`` corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should lie in [0, 1].
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
333
+
334
def check_inputs(self, prompt, height, width, callback_steps):
    """Validate call arguments, raising ``ValueError`` on the first problem.

    ``prompt`` must be a ``str`` or ``list``, ``height`` and ``width`` must be
    multiples of 8, and ``callback_steps`` must be a positive ``int``.
    """
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
348
+
349
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
    """Create (or validate) the initial noise latents and scale them.

    The expected shape is ``(batch, channels, frames, height // f, width // f)``
    with ``f = self.vae_scale_factor``. Supplied ``latents`` must match that
    shape exactly. The result is multiplied by the scheduler's
    ``init_noise_sigma``.

    Raises:
        ValueError: on a generator-list / batch-size mismatch, or when the
            supplied ``latents`` have an unexpected shape.
    """
    shape = (
        batch_size,
        num_channels_latents,
        video_length,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)
    else:
        rand_device = "cpu" if device.type == "mps" else device
        if isinstance(generator, list):
            # One sample per generator, then stack along the batch dimension.
            single_shape = (1,) + shape[1:]
            samples = [
                torch.randn(single_shape, generator=gen, device=rand_device, dtype=dtype)
                for gen in generator
            ]
            latents = torch.cat(samples, dim=0).to(device)
        else:
            # NOTE(review): a single (non-list) generator is not used here; the
            # noise comes from the global RNG -- confirm this is intended.
            latents = torch.randn(shape, dtype=dtype).to(device)

    # Scale the initial noise by the standard deviation the scheduler requires.
    return latents * self.scheduler.init_noise_sigma
377
+
378
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[str, List[str]],
    video_length: Optional[int],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    camera_matrixs = None,
    class_labels = None,
    prompt_ids = None,
    unet_condition_type = None,
    pose_guider = None,
    pose_image = None,
    img_proj=None,
    use_noise=True,
    use_shifted_noise=False,
    rescale = 0.7,
    **kwargs,
):
    """Run the multi-view denoising loop conditioned on a reference image.

    Args:
        prompt: Text prompt(s); also determines the batch size.
        image: Reference image batch, forwarded to ``_encode_image``.
        video_length: Number of frames/views per sample.
        height, width: Output size; default to the UNet sample size times
            the VAE scale factor.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance weight; guidance is active
            when it is greater than 1.
        negative_prompt: Optional negative prompt(s).
        num_videos_per_prompt: Samples generated per prompt.
        eta: DDIM η, forwarded only if the scheduler accepts it.
        generator: Forwarded to ``prepare_latents`` and the scheduler step.
        latents: Optional pre-made initial latents.
        output_type: ``"tensor"`` converts the decoded numpy video to a tensor.
        return_dict: When false, the raw video is returned.
        callback, callback_steps: Optional progress callback and its period.
        camera_matrixs: Per-view camera matrices passed to the UNet; doubled
            under classifier-free guidance.
        class_labels, prompt_ids, kwargs: Accepted but unused here.
        unet_condition_type: ``'text'`` or ``'image'``; selects which
            embeddings condition the main UNet.
        pose_guider, pose_image: Optional module and its input; its output is
            added to the latent model input at every step.
        img_proj: Optional image projection, forwarded to ``_encode_image``.
        use_noise: When true, the reference latents are noised to the current
            timestep before entering the reference UNet.
        use_shifted_noise: When true, guidance is rescaled using ``rescale``.
        rescale: Blend factor of the guidance rescale.

    Returns:
        ``TuneAVideoPipelineOutput`` or, if ``return_dict`` is false, the video.

    Raises:
        ValueError: invalid inputs or an unknown ``unet_condition_type``.
    """
    # Default height and width to the UNet's native size.
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    self.check_inputs(prompt, height, width, callback_steps)
    if unet_condition_type not in ('text', 'image'):
        # Previously `raise('need unet_condition_type')`, which is itself a
        # TypeError because a str is not an exception.
        raise ValueError(
            f"`unet_condition_type` must be 'text' or 'image' but is {unet_condition_type!r}."
        )

    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # `guidance_scale` is the weight `w` of equation (2) of the Imagen paper
    # (https://arxiv.org/pdf/2205.11487.pdf); 1 means no guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # Encode the reference image.
    image_embeddings, image_latents = self._encode_image(
        image, device, num_videos_per_prompt, do_classifier_free_guidance, img_proj=img_proj
    )
    image_latents = rearrange(image_latents, "(b f) c h w -> b c f h w", f=1)

    # Encode the text prompt.
    text_embeddings = self._encode_prompt(
        prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Prepare timesteps.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # Prepare latent variables.
    num_channels_latents = self.unet.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        video_length,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )
    latents_dtype = latents.dtype

    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # Camera matrices: duplicate for the unconditional half and cast once.
    if camera_matrixs is not None:
        if do_classifier_free_guidance:
            camera_matrixs = torch.cat([camera_matrixs] * 2)
        camera_matrixs = camera_matrixs.to(torch.float32)

    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    pose_embeds = None
    if pose_guider is not None:
        if len(pose_image.shape) == 5:
            pose_embeds = pose_guider(rearrange(pose_image, "b f c h w -> (b f) c h w"))
            pose_embeds = rearrange(pose_embeds, "(b f) c h w-> b c f h w ", f=video_length)
        else:
            pose_embeds = pose_guider(pose_image).unsqueeze(0)
        # Only duplicate when the latents are duplicated too; the previous
        # unconditional doubling mismatched the batch when guidance was off.
        if do_classifier_free_guidance:
            pose_embeds = torch.cat([pose_embeds] * 2, dim=0)

    # Conditioning for the main UNet is loop-invariant: repeat the chosen
    # embeddings once per view and fold the view axis into the batch.
    num_views = latents.shape[2]
    cond_source = text_embeddings if unet_condition_type == 'text' else image_embeddings
    cond_per_view = cond_source.unsqueeze(1).repeat(1, num_views, 1, 1)
    encoder_hidden_states_unet_cond = rearrange(cond_per_view, 'B Nv d c -> (B Nv) d c').to(torch.float32)

    # With a reference UNet the main UNet runs in mode "r", otherwise "n"
    # (presumably read / none -- confirm in the attention processors).
    attention_mode = "r" if self.ref_unet is not None else "n"

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(tqdm.tqdm(timesteps)):
            # Expand the latents if we are doing classifier-free guidance.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            if pose_embeds is not None:
                latent_model_input = latent_model_input + pose_embeds

            # Drawn every step (even when unused) to keep the RNG stream unchanged.
            noise_cond = torch.randn_like(image_latents)
            if use_noise:
                cond_latents = self.scheduler.add_noise(image_latents, noise_cond, t)
            else:
                cond_latents = image_latents
            cond_latent_model_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents
            cond_latent_model_input = self.scheduler.scale_model_input(cond_latent_model_input, t)

            # Reference pass: its prediction is not used; it is run in mode
            # "w" with the shared `ref_dict` that the main UNet receives next.
            ref_dict = {}
            if self.ref_unet is not None:
                self.ref_unet(
                    cond_latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings.to(torch.float32),
                    cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
                )

            # Main multi-view prediction.
            noise_pred = self.unet(
                latent_model_input.to(torch.float32),
                t,
                encoder_hidden_states=encoder_hidden_states_unet_cond,
                camera_matrixs=camera_matrixs,
                cross_attention_kwargs=dict(
                    mode=attention_mode, ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance
                ),
            ).sample.to(dtype=latents_dtype)

            # Perform guidance.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                if use_shifted_noise:
                    # Guidance rescale: pull the std of the guided prediction
                    # back towards that of the conditional prediction.
                    std_pos = noise_pred_text.std([1, 2, 3], keepdim=True)
                    std_cfg = guided.std([1, 2, 3], keepdim=True)
                    factor = std_pos / std_cfg
                    factor = rescale * factor + (1 - rescale)
                    noise_pred = guided * factor
                else:
                    noise_pred = guided

            # Compute the previous noisy sample x_t -> x_t-1.
            noise_pred = rearrange(noise_pred, "(b f) c h w -> b c f h w", f=video_length)
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # Progress bar and optional callback.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # Post-processing.
    video = self.decode_latents(latents)

    if output_type == "tensor":
        video = torch.from_numpy(video)

    if not return_dict:
        return video

    return TuneAVideoPipelineOutput(videos=video)
2D_Stage/tuneavideo/util.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+ import cv2
6
+ import torch
7
+ import torchvision
8
+
9
+ from tqdm import tqdm
10
+ from einops import rearrange
11
+
12
def shifted_noise(betas, image_d=512, noise_d=256, shifted_noise=True):
    """Return a beta schedule that is resolution-shifted and ends at zero SNR.

    The cumulative alphas are optionally shifted by ``(image_d / noise_d) ** 2``,
    then their square roots are translated so the last entry is zero and
    rescaled so the first entry keeps its value. The result is converted
    back to betas.
    """
    alphas_bar = torch.cumprod(1 - betas, dim=0)
    d = (image_d / noise_d) ** 2
    if shifted_noise:
        alphas_bar = alphas_bar / (d - (d - 1) * alphas_bar)

    sqrt_bar = torch.sqrt(alphas_bar)
    first = sqrt_bar[0].clone()
    last = sqrt_bar[-1].clone()
    # Shift so the last timestep is zero, then scale so the first is unchanged.
    sqrt_bar = (sqrt_bar - last) * (first / (first - last))

    # Convert the rescaled sqrt(alpha_bar) back to per-step betas.
    alphas_bar = sqrt_bar ** 2
    step_alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - step_alphas
33
+
34
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    """Tile a batch of videos into one grid per frame and save as an animation.

    Args:
        videos: Tensor laid out as (b, c, t, h, w).
        path: Output file; its parent directory is created if needed.
        rescale: When true, values are mapped from [-1, 1] to [0, 1] first.
        n_rows: Forwarded to ``make_grid`` as ``nrow``.
        fps: Frame rate; each frame lasts ``1000 / fps``.
    """
    frames = []
    for frame_batch in rearrange(videos, "b c t h w -> t b c h w"):
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        # channels-first -> channels-last
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        frames.append((grid * 255).numpy().astype(np.uint8))

    # A bare filename has an empty dirname, and os.makedirs("") raises.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    imageio.mimsave(path, frames, duration=1000/fps)
47
+
48
def save_imgs_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    """Save one grid image per frame as ``view_<i>.png`` inside ``path``.

    Args:
        videos: Tensor laid out as (b, c, t, h, w).
        path: Output directory; created if it does not exist.
        rescale: When true, values are mapped from [-1, 1] to [0, 1] first.
        n_rows: Forwarded to ``make_grid`` as ``nrow``.
        fps: Unused; kept for signature parity with ``save_videos_grid``.
    """
    # The files are written *inside* `path`, so `path` itself must exist.
    # Previously only its parent was created and cv2.imwrite failed silently.
    os.makedirs(path, exist_ok=True)
    for i, frame_batch in enumerate(rearrange(videos, "b c t h w -> t b c h w")):
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        grid = (grid * 255).numpy().astype(np.uint8)
        # Reverse the channel axis (RGB -> BGR) for OpenCV.
        cv2.imwrite(os.path.join(path, f'view_{i}.png'), grid[:, :, ::-1])
58
+
59
def imgs_grid(videos: torch.Tensor, rescale=False, n_rows=4, fps=8):
    """Return one uint8 grid image (channels last) per frame of ``videos``.

    ``videos`` is laid out as (b, c, t, h, w). ``rescale`` maps [-1, 1] to
    [0, 1] before conversion. ``fps`` is unused.
    """

    def to_uint8_grid(frame_batch):
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        return (grid * 255).numpy().astype(np.uint8)

    per_frame = rearrange(videos, "b c t h w -> t b c h w")
    return [to_uint8_grid(frame_batch) for frame_batch in per_frame]
71
+
72
+ # DDIM Inversion
73
@torch.no_grad()
def init_prompt(prompt, pipeline):
    """Encode the empty prompt and ``prompt``; return them concatenated.

    The unconditional embedding comes first, so ``.chunk(2)`` on the result
    yields ``(uncond, cond)``.
    """
    tokenizer = pipeline.tokenizer
    text_encoder = pipeline.text_encoder

    empty_tokens = tokenizer(
        [""], padding="max_length", max_length=tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = text_encoder(empty_tokens.input_ids.to(pipeline.device))[0]

    prompt_tokens = tokenizer(
        [prompt],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(prompt_tokens.input_ids.to(pipeline.device))[0]

    return torch.cat([uncond_embeddings, text_embeddings])
91
+
92
+
93
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    """Take one DDIM inversion step, moving ``sample`` to ``timestep``.

    The sample is treated as belonging to the timestep one inference stride
    below ``timestep`` (capped at 999; the scheduler's ``final_alpha_cumprod``
    is used when that index is negative).
    """
    stride = ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps
    current_t = min(timestep - stride, 999)
    target_t = timestep

    if current_t >= 0:
        alpha_prod_t = ddim_scheduler.alphas_cumprod[current_t]
    else:
        alpha_prod_t = ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[target_t]

    # Predicted clean sample at the current level...
    beta_prod_t = 1 - alpha_prod_t
    predicted_x0 = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    # ...re-noised to the target level along the predicted noise direction.
    direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    return alpha_prod_t_next ** 0.5 * predicted_x0 + direction
104
+
105
+
106
def get_noise_pred_single(latents, t, context, unet):
    """Run ``unet`` once and return the ``"sample"`` entry of its output."""
    return unet(latents, t, encoder_hidden_states=context)["sample"]
109
+
110
+
111
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    """Invert ``latent`` for ``num_inv_steps`` DDIM steps.

    Timesteps are walked from the end of ``ddim_scheduler.timesteps`` towards
    the start, using only the conditional prompt embedding. Returns the list
    of latents, starting with the input.
    """
    _, cond_embeddings = init_prompt(prompt, pipeline).chunk(2)
    trajectory = [latent]
    current = latent.clone().detach()
    for step in tqdm(range(num_inv_steps)):
        schedule = ddim_scheduler.timesteps
        t = schedule[len(schedule) - step - 1]
        noise_pred = get_noise_pred_single(
            current.to(torch.float32), t, cond_embeddings.to(torch.float32), pipeline.unet
        )
        current = next_step(noise_pred, t, current, ddim_scheduler)
        trajectory.append(current)
    return trajectory
123
+
124
+
125
@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    """Return the DDIM inversion trajectory of ``video_latent`` (see ``ddim_loop``)."""
    return ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
2D_Stage/webui.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import glob
4
+
5
+ import io
6
+ import argparse
7
+ import inspect
8
+ import os
9
+ import random
10
+ from typing import Dict, Optional, Tuple
11
+ from omegaconf import OmegaConf
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.utils.checkpoint
16
+
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import set_seed
19
+ from diffusers import AutoencoderKL, DDIMScheduler
20
+ from diffusers.utils import check_min_version
21
+ from tqdm.auto import tqdm
22
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
23
+ from torchvision import transforms
24
+
25
+ from tuneavideo.models.unet_mv2d_condition import UNetMV2DConditionModel
26
+ from tuneavideo.models.unet_mv2d_ref import UNetMV2DRefModel
27
+ from tuneavideo.models.PoseGuider import PoseGuider
28
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
29
+ from tuneavideo.util import shifted_noise
30
+ from einops import rearrange
31
+ import PIL
32
+ from PIL import Image
33
+ from torchvision.utils import save_image
34
+ import json
35
+ import cv2
36
+
37
+ import onnxruntime as rt
38
+ from huggingface_hub.file_download import hf_hub_download
39
+ from rm_anime_bg.cli import get_mask, SCALE
40
+
41
+ from huggingface_hub import hf_hub_download, list_repo_files
42
+
43
repo_id = "zjpshadow/CharacterGen"
all_files = list_repo_files(repo_id, revision="main")

# Download every 2D-stage asset of the repo that is not already present one
# directory above the working directory.
for file in all_files:
    if file.startswith("2D_Stage") and not os.path.exists("../" + file):
        hf_hub_download(repo_id, file, local_dir="../")
51
+
52
class rm_bg_api:
    """Background remover backed by the ``skytnt/anime-seg`` ONNX model."""

    def __init__(self, force_cpu: Optional[bool] = True):
        """Download the model and open an ONNX Runtime session.

        Args:
            force_cpu: When false and the CUDA execution provider is
                available, inference runs on CUDA; otherwise on CPU.
        """
        session_infer_path = hf_hub_download(
            repo_id="skytnt/anime-seg", filename="isnetis.onnx",
        )
        providers: list[str] = ["CPUExecutionProvider"]
        if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
            providers = ["CUDAExecutionProvider"]

        self.session_infer = rt.InferenceSession(
            session_infer_path, providers=providers,
        )

    def remove_background(
        self,
        imgs: list[np.ndarray],
        alpha_min: float,
        alpha_max: float,
    ) -> list:
        """Return RGBA PIL images with the background replaced and masked out.

        Args:
            imgs: Images as arrays; 4-channel inputs are converted to RGB
                first, other inputs are passed to the model unchanged.
            alpha_min: Mask values below this are forced to 0.
            alpha_max: Mask values above this are forced to 1.

        Returns:
            One PIL image per input: the composited colour channels plus the
            mask as the last channel.
        """
        process_imgs = []
        for img in imgs:
            # Only drop the alpha channel when there is one; cvtColor with
            # COLOR_RGBA2RGB rejects 3-channel input.
            if img.ndim == 3 and img.shape[2] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            mask = get_mask(self.session_infer, img)

            mask[mask < alpha_min] = 0.0  # type: ignore
            mask[mask > alpha_max] = 1.0  # type: ignore

            # Blend the foreground over a constant SCALE-valued background.
            img_after = (mask * img + SCALE * (1 - mask)).astype(np.uint8)  # type: ignore
            mask = (mask * SCALE).astype(np.uint8)  # type: ignore
            # Append the mask as the last channel (mask is presumably (H, W, 1)).
            img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
            process_imgs.append(Image.fromarray(img_after))
        return process_imgs
87
+
88
+ check_min_version("0.24.0")
89
+
90
+ logger = get_logger(__name__, log_level="INFO")
91
+
92
def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
97
+
98
def get_bg_color(bg_color):
    """Resolve a background colour spec into a length-3 array.

    Accepts ``'white'``, ``'black'``, ``'gray'``, ``'random'`` or a float
    (used for all three channels).

    Raises:
        NotImplementedError: for any other value.
    """
    if bg_color == 'white':
        return np.array([1., 1., 1.], dtype=np.float32)
    if bg_color == 'black':
        return np.array([0., 0., 0.], dtype=np.float32)
    if bg_color == 'gray':
        return np.array([0.5, 0.5, 0.5], dtype=np.float32)
    if bg_color == 'random':
        return np.random.rand(3)
    if isinstance(bg_color, float):
        return np.array([bg_color] * 3, dtype=np.float32)
    raise NotImplementedError
112
+
113
def process_image(image, totensor):
    """Crop, centre and composite an input image for the 2D stage.

    The image is converted to RGBA, cropped to the bounding box of its
    non-transparent pixels, pasted centred on a 2:3 (width:height) canvas,
    resized to 512x768, and alpha-composited over a gray background.

    Args:
        image: PIL image.
        totensor: Callable applied to the final float32 (H, W, 3) array.

    Returns:
        Whatever ``totensor`` returns for the composited image.

    Raises:
        ValueError: if the image has no non-transparent pixel.
    """
    if not image.mode == "RGBA":
        image = image.convert("RGBA")

    # Bounding box of the non-transparent pixels.
    rows, cols = np.nonzero(np.array(image)[..., 3])
    if rows.size == 0:
        raise ValueError("input image is fully transparent; nothing to crop")
    # PIL's crop box excludes its right/lower edge, so add 1 to keep the last
    # opaque column and row (they were previously cut off).
    image = image.crop((int(cols.min()), int(rows.min()), int(cols.max()) + 1, int(rows.max()) + 1))

    # Paste to the centre of a 2:3 canvas sized by the larger dimension.
    max_dim = max(image.width, image.height)
    max_height = max_dim
    max_width = int(max_dim / 3 * 2)
    new_image = Image.new("RGBA", (max_width, max_height))
    left = (max_width - image.width) // 2
    top = (max_height - image.height) // 2
    new_image.paste(image, (left, top))

    image = new_image.resize((512, 768), resample=PIL.Image.BICUBIC)
    image = np.array(image)
    image = image.astype(np.float32) / 255.
    assert image.shape[-1] == 4  # RGBA
    alpha = image[..., 3:4]
    bg_color = get_bg_color("gray")
    # Alpha-composite over the background colour.
    image = image[..., :3] * alpha + bg_color * (1 - alpha)
    return totensor(image)
143
+
144
class Inference_API:
    """Holds a lazily built ``TuneAVideoPipeline`` and runs single-image inference."""

    def __init__(self):
        # Built on the first `inference` call and reused afterwards.
        self.validation_pipeline = None

    @torch.no_grad()
    def inference(self, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
                  pose_guider=None, use_noise=True, use_shifted_noise=False, noise_d=256, crop=False, seed=100, timestep=20):
        """Generate four views from one input image.

        Args:
            input_image: PIL image, preprocessed with ``process_image``.
            vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
                text_encoder: Components used to build the pipeline on the
                first call.
            pretrained_model_path: Location of the ``scheduler`` subfolder.
            generator: Forwarded to the pipeline.
            validation: Mapping expanded into the pipeline call; must expose
                ``video_length``.
            val_width, val_height: Output size.
            unet_condition_type: Forwarded to the pipeline.
            pose_guider: Accepted but not forwarded (the pipeline is called
                with ``pose_guider=None``).
            use_noise, use_shifted_noise: Forwarded to the pipeline.
            noise_d: Noise resolution for the shifted-noise schedule.
            crop: When true, pose images are cropped instead of resized.
            seed: Seed applied through ``set_seed``.
            timestep: Number of inference steps.

        Returns:
            List of four PIL images.
        """
        set_seed(seed)
        # Build the pipeline once. NOTE(review): later calls reuse it, so a
        # different `use_shifted_noise` / `noise_d` has no effect on the
        # scheduler after the first call -- confirm this is acceptable.
        if self.validation_pipeline is None:
            noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
            if use_shifted_noise:
                print(f"enable shifted noise for {val_height} to {noise_d}")
                betas = shifted_noise(noise_scheduler.betas, image_d=val_height, noise_d=noise_d)
                noise_scheduler.betas = betas
                noise_scheduler.alphas = 1 - betas
                noise_scheduler.alphas_cumprod = torch.cumprod(noise_scheduler.alphas, dim=0)
            self.validation_pipeline = TuneAVideoPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,feature_extractor=feature_extractor,image_encoder=image_encoder,
                scheduler=noise_scheduler
            )
            self.validation_pipeline.enable_vae_slicing()
            self.validation_pipeline.set_progress_bar_config(disable=True)

        totensor = transforms.ToTensor()

        # Close the metadata file deterministically instead of leaking the handle.
        with open("./material/pose.json", "r") as pose_file:
            metas = json.load(pose_file)
        cameras = []
        pose_images = []
        input_path = "./material"
        for lm in metas:
            # lm[0]: 16 camera values -> transposed 4x4 -> top 3 rows, flattened to 12.
            cameras.append(torch.tensor(np.array(lm[0]).reshape(4, 4).transpose(1,0)[:3, :4]).reshape(-1))
            with Image.open(os.path.join(input_path, lm[1])) as pose_image:
                if not crop:
                    # PIL's resize takes (width, height); the arguments were
                    # previously passed as (height, width).
                    resized = pose_image.resize((val_width, val_height), resample=PIL.Image.BICUBIC)
                    pose_images.append(totensor(np.asarray(resized).astype(np.float32) / 255.))
                else:
                    crop_area = (128, 0, 640, 768)
                    pose_images.append(totensor(np.array(pose_image.crop(crop_area)).astype(np.float32)) / 255.)
        camera_matrixs = torch.stack(cameras).unsqueeze(0).to("cuda")
        pose_imgs_in = torch.stack(pose_images).to("cuda")
        prompts = "high quality, best quality"
        prompt_ids = tokenizer(
            prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
            return_tensors="pt"
        ).input_ids[0]

        # (B*Nv, 3, H, W) with B = Nv = 1
        weight_dtype = torch.bfloat16
        imgs_in = process_image(input_image, totensor)
        imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

        with torch.autocast("cuda", dtype=weight_dtype):
            imgs_in = imgs_in.to("cuda")
            # NOTE(review): `pose_guider` is hard-wired to None here, so the
            # method's own `pose_guider` argument is ignored -- confirm intended.
            out = self.validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
                                           num_inference_steps=timestep,
                                           camera_matrixs=camera_matrixs.to(weight_dtype), prompt_ids=prompt_ids,
                                           height=val_height, width=val_width, unet_condition_type=unet_condition_type,
                                           pose_guider=None, pose_image=pose_imgs_in, use_noise=use_noise,
                                           use_shifted_noise=use_shifted_noise, **validation).videos
            out = rearrange(out, "B C f H W -> (B f) C H W", f=validation.video_length)

        # Round-trip each of the first four frames through an in-memory PNG.
        image_outputs = []
        for bs in range(4):
            img_buf = io.BytesIO()
            save_image(out[bs], img_buf, format='PNG')
            img_buf.seek(0)
            img = Image.open(img_buf)
            image_outputs.append(img)
        torch.cuda.empty_cache()
        return image_outputs
218
+
219
@torch.no_grad()
def main(
    pretrained_model_path: str,
    image_encoder_path: str,
    ckpt_dir: str,
    validation: Dict,
    local_crossattn: bool = True,
    unet_from_pretrained_kwargs=None,
    unet_condition_type=None,
    use_pose_guider=False,
    use_noise=True,
    use_shifted_noise=False,
    noise_d=256
):
    """Load the 2D-stage models and launch the Gradio demo.

    The tokenizer, text encoder, VAE and both UNets are loaded from
    ``pretrained_model_path``; the CLIP image encoder from
    ``image_encoder_path``. Fine-tuned weights are then read from ``ckpt_dir``:
    ``pytorch_model.bin`` for the UNet, and either ``pytorch_model_1.bin``
    (reference UNet) or, when ``use_pose_guider`` is set,
    ``pytorch_model_1.bin`` (pose guider) plus ``pytorch_model_2.bin``
    (reference UNet). All models are moved to CUDA in float16.

    Args:
        pretrained_model_path: Directory with the ``tokenizer``,
            ``text_encoder``, ``vae`` and ``unet`` subfolders.
        image_encoder_path: Path of the CLIP vision model.
        ckpt_dir: Directory holding the fine-tuned ``pytorch_model*.bin`` files.
        validation: Validation settings forwarded to the inference call.
        local_crossattn: Forwarded to both ``from_pretrained_2d`` calls.
        unet_from_pretrained_kwargs: Extra keyword arguments for both
            ``from_pretrained_2d`` calls; ``None`` means no extra arguments.
        unet_condition_type: Forwarded to the inference call.
        use_pose_guider: Whether to build and load a ``PoseGuider``.
        use_noise: Forwarded to the inference call.
        use_shifted_noise: Forwarded to the inference call.
        noise_d: Forwarded to the inference call.
    """
    # The default is None, which cannot be **-unpacked; treat it as "no extras".
    if unet_from_pretrained_kwargs is None:
        unet_from_pretrained_kwargs = {}

    device = "cuda"

    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)
    feature_extractor = CLIPImageProcessor()
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    if use_pose_guider:
        pose_guider = PoseGuider(noise_latent_channels=4).to(device)
    else:
        pose_guider = None

    # Checkpoint file numbering shifts by one when a pose guider was saved too.
    unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
    if use_pose_guider:
        pose_guider_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
        ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_2.bin"), map_location="cpu")
        pose_guider.load_state_dict(pose_guider_params)
    else:
        ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
    unet.load_state_dict(unet_params)
    ref_unet.load_state_dict(ref_unet_params)

    weight_dtype = torch.float16

    text_encoder.to(device, dtype=weight_dtype)
    image_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    ref_unet.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    vae.requires_grad_(False)
    unet.requires_grad_(False)
    ref_unet.requires_grad_(False)

    generator = torch.Generator(device=device)
    inferapi = Inference_API()
    remove_api = rm_bg_api()

    def gen4views(image, width, height, seed, timestep, remove_bg):
        """Gradio callback: optionally strip the background, then run inference."""
        if remove_bg:
            image = remove_api.remove_background(
                imgs=[np.array(image)],
                alpha_min=0.1,
                alpha_max=0.9,
            )[0]
        return inferapi.inference(
            image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, pretrained_model_path,
            generator, validation, width, height, unet_condition_type,
            pose_guider=pose_guider, use_noise=use_noise, use_shifted_noise=use_shifted_noise, noise_d=noise_d,
            crop=True, seed=seed, timestep=timestep
        )

    with gr.Blocks() as demo:
        gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration")
        gr.Markdown("# 2D Stage: One Image to Four Views of Character Image")
        gr.Markdown("**Please Upload the Image without background, and the pictures uploaded should preferably be full-body frontal photos.**")
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Upload Image(without background)", image_mode="RGBA", width=768, height=512)
                gr.Examples(
                    label="Example Images",
                    examples=glob.glob("./material/examples/*.png"),
                    inputs=[img_input]
                )
                with gr.Row():
                    width_input = gr.Number(label="Width", value=512)
                    height_input = gr.Number(label="Height", value=768)
                    seed_input = gr.Number(label="Seed", value=2333)
                    remove_bg = gr.Checkbox(label="Remove Background (with algorithm)", value=False)
                timestep = gr.Slider(minimum=10, maximum=70, step=1, value=40, label="Timesteps")
            with gr.Column():
                button = gr.Button(value="Generate")
                output = gr.Gallery(label="4 views of Character Image")

        button.click(
            fn=gen4views,
            inputs=[img_input, width_input, height_input, seed_input, timestep, remove_bg],
            outputs=[output]
        )

    demo.launch()
317
+
318
+ if __name__ == "__main__":
319
+ parser = argparse.ArgumentParser()
320
+ parser.add_argument("--config", type=str, default="./configs/infer.yaml")
321
+ args = parser.parse_args()
322
+
323
+ main(**OmegaConf.load(args.config))
3D_Stage/configs/infer.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ system_cls: lrm.systems.multiview_lrm.MultiviewLRM
2
+ data:
3
+ cond_width: 504
4
+ cond_height: 504
5
+
6
+ system:
7
+ weights: ./models/lrm.ckpt
8
+
9
+ weights_ignore_modules:
10
+ - decoder.heads.density
11
+
12
+ check_train_every_n_steps: 100
13
+
14
+ camera_embedder_cls: lrm.models.camera.LinearCameraEmbedder
15
+ camera_embedder:
16
+ in_channels: 16
17
+ out_channels: 768
18
+ conditions:
19
+ - c2w_cond
20
+
21
+ # image tokenizer transforms input images to tokens
22
+ image_tokenizer_cls: lrm.models.tokenizers.image.DINOV2SingleImageTokenizer
23
+ image_tokenizer:
24
+ pretrained_model_name_or_path: "./models/base"
25
+ freeze_backbone_params: false
26
+ enable_memory_efficient_attention: true
27
+ enable_gradient_checkpointing: true
28
+ # camera modulation to the DINO transformer layers
29
+ modulation: true
30
+ modulation_zero_init: true
31
+ modulation_single_layer: true
32
+ modulation_cond_dim: ${system.camera_embedder.out_channels}
33
+
34
+ # tokenizer gives a tokenized representation for the 3D scene
35
+ # triplane tokens in this case
36
+ tokenizer_cls: lrm.models.tokenizers.triplane.TriplaneLearnablePositionalEmbedding
37
+ tokenizer:
38
+ plane_size: 32
39
+ num_channels: 512
40
+
41
+ # backbone network is a transformer that takes scene tokens (potentially with conditional image tokens)
42
+ # and outputs scene tokens of the same size
43
+ backbone_cls: lrm.models.transformers.transformer_1d.Transformer1D
44
+ backbone:
45
+ in_channels: ${system.tokenizer.num_channels}
46
+ num_attention_heads: 16
47
+ attention_head_dim: 64
48
+ num_layers: 12
49
+ cross_attention_dim: 768 # hard-code, =DINO feature dim
50
+ # camera modulation to the transformer layers
51
+ # if not needed, set norm_type=layer_norm and do not specify cond_dim_ada_norm_continuous
52
+ norm_type: "layer_norm"
53
+ enable_memory_efficient_attention: true
54
+ gradient_checkpointing: true
55
+
56
+ # post processor takes scene tokens and outputs the final scene parameters that will be used for rendering
57
+ # in this case, triplanes are upsampled and the features are condensed
58
+ post_processor_cls: lrm.models.networks.TriplaneUpsampleNetwork
59
+ post_processor:
60
+ in_channels: 512
61
+ out_channels: 80
62
+
63
+ renderer_cls: lrm.models.renderers.triplane_dmtet.TriplaneDMTetRenderer
64
+ renderer:
65
+ radius: 0.6 # slightly larger than 0.5
66
+ feature_reduction: concat
67
+ sdf_bias: -2.
68
+ tet_dir: "./load/tets/"
69
+ isosurface_resolution: 256
70
+ enable_isosurface_grid_deformation: false
71
+ sdf_activation: negative
72
+
73
+ decoder_cls: lrm.models.networks.MultiHeadMLP
74
+ decoder:
75
+ in_channels: 240 # 3 * 80
76
+ n_neurons: 64
77
+ n_hidden_layers_share: 8
78
+ heads:
79
+ - name: sdf
80
+ out_channels: 1
81
+ n_hidden_layers: 1
82
+ output_activation: null
83
+ - name: features
84
+ out_channels: 3
85
+ n_hidden_layers: 1
86
+ output_activation: null # activate in material
87
+ activation: silu
88
+ chunk_mode: deferred
89
+ chunk_size: 131072
90
+
91
+ exporter:
92
+ fmt: "obj"
93
+ #visual: "vertex"
94
+ visual: "uv"
95
+ save_uv: True
96
+ save_texture: True
97
+ uv_unwrap_method: "open3d"
98
+ output_path: "./outputs"
99
+
100
+ material_cls: lrm.models.materials.no_material.NoMaterial
101
+
102
+ background_cls: lrm.models.background.solid_color_background.SolidColorBackground
103
+ background:
104
+ color: [0.5, 0.5, 0.5]
3D_Stage/load/tets/generate_tets.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import os
11
+
12
+ import numpy as np
13
+
14
+ """
15
+ This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
16
+ to generate a tet grid
17
+ 1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
18
+ 2) Run the function below to generate a file `cube_32_tet.tet`
19
+ """
20
+
21
+
22
def generate_tetrahedron_grid_file(res=32, root=".."):
    """Run the Quartet binary in ``root`` to build a tetrahedral grid.

    Writes ``meshes/cube_{res}_tet.tet`` and ``meshes/cube_boundary_{res}.obj``
    relative to ``root``.

    Args:
        res: Grid resolution; Quartet is invoked with a cell size of ``1 / res``.
        root: Directory containing the ``quartet`` executable and ``meshes/``.

    Raises:
        FileNotFoundError: If ``root`` or the ``quartet`` binary does not exist.
        subprocess.CalledProcessError: If Quartet exits with a non-zero status.
    """
    import subprocess

    frac = 1.0 / res
    root = os.path.abspath(root)
    # An argument list with cwd= avoids the shell entirely: no quoting issues
    # with ``root``, and Quartet is never run from the wrong directory (the old
    # ``cd root; ./quartet`` string kept going even when the cd failed).
    subprocess.run(
        [
            os.path.join(root, "quartet"),
            "meshes/cube.obj",
            str(frac),
            f"meshes/cube_{res}_tet.tet",
            "-s",
            f"meshes/cube_boundary_{res}.obj",
        ],
        cwd=root,
        check=True,
    )
26
+
27
+
28
+ """
29
+ This code segment shows how to convert from a quartet .tet file to compressed npz file
30
+ """
31
+
32
+
33
def convert_from_quartet_to_npz(quartetfile="cube_32_tet.tet", npzfile="32_tets"):
    """Convert a Quartet ``.tet`` file into a compressed ``.npz`` archive.

    The first line of the file is a header whose second and third
    whitespace-separated fields are the vertex count and the tetrahedron count.
    The vertex rows follow, then the index rows.

    Args:
        quartetfile: Path of the ``.tet`` file to read.
        npzfile: Output path handed to ``np.savez_compressed`` (which appends
            ``.npz`` when the name lacks it). The archive stores the arrays
            ``vertices`` and ``indices``.
    """
    # Only the header is parsed by hand; close the handle before numpy
    # re-opens the same file for the bulk reads.
    with open(quartetfile, "r") as tet_file:
        header_fields = tet_file.readline().split()
    numvertices = int(header_fields[1])
    numtets = int(header_fields[2])
    print(numvertices, numtets)

    # load vertices
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)
    print(vertices.shape)

    # load indices (they start right after the header and the vertex rows)
    indices = np.loadtxt(
        quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets
    )
    print(indices.shape)

    np.savez_compressed(npzfile, vertices=vertices, indices=indices)
51
+
52
+
53
# Only run the (slow, machine-specific) generation when executed as a script,
# so that importing this module for its helper functions has no side effects.
if __name__ == "__main__":
    # NOTE(review): hard-coded location of a local Quartet checkout; adjust
    # before running on another machine.
    root = "/home/gyc/quartet"
    for res in [300, 350, 400]:
        generate_tetrahedron_grid_file(res, root)
        convert_from_quartet_to_npz(
            os.path.join(root, f"meshes/cube_{res}_tet.tet"), npzfile=f"{res}_tets"
        )
3D_Stage/lrm/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+
4
def find(cls_string):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named attribute.

    Everything before the last dot is imported as a module; the final component
    is looked up on that module and returned.
    """
    module_path, _, attr_name = cls_string.rpartition(".")
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)
10
+
11
+
12
+ ### syntactic sugar for logging utilities ###
13
+ import logging
14
+
15
+ logger = logging.getLogger("pytorch_lightning")
16
+
17
+ from pytorch_lightning.utilities.rank_zero import (
18
+ rank_zero_debug,
19
+ rank_zero_info,
20
+ rank_zero_only,
21
+ )
22
+
23
+ debug = rank_zero_debug
24
+ info = rank_zero_info
25
+
26
+
27
@rank_zero_only
def warn(*args, **kwargs):
    """Emit a warning through the module logger, on the rank-zero process only."""
    # Logger.warn is a deprecated alias; Logger.warning is the supported name.
    logger.warning(*args, **kwargs)
3D_Stage/lrm/models/__init__.py ADDED
File without changes
3D_Stage/lrm/models/background/__init__.py ADDED
File without changes
3D_Stage/lrm/models/background/base.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import lrm
9
+ from ...utils.base import BaseModule
10
+ from ...utils.typing import *
11
+
12
+
13
class BaseBackground(BaseModule):
    """Abstract background model; subclasses implement ``forward``."""

    @dataclass
    class Config(BaseModule.Config):
        """No options beyond those of ``BaseModule.Config``."""

    cfg: Config

    def configure(self):
        """Nothing to set up for the base class."""

    def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
        """Return a background value for every direction in ``dirs`` (abstract)."""
        raise NotImplementedError
3D_Stage/lrm/models/background/solid_color_background.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import lrm
9
+ from .base import BaseBackground
10
+ from ...utils.typing import *
11
+
12
+
13
class SolidColorBackground(BaseBackground):
    """Background that returns one flat color for every query direction."""

    @dataclass
    class Config(BaseBackground.Config):
        n_output_dims: int = 3
        color: Tuple = (1.0, 1.0, 1.0)
        learned: bool = False
        random_aug: bool = False
        random_aug_prob: float = 0.5

    cfg: Config

    def configure(self) -> None:
        """Create ``env_color`` as a parameter (learned) or a buffer (fixed)."""
        self.env_color: Float[Tensor, "Nc"]
        if not self.cfg.learned:
            self.register_buffer(
                "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32)
            )
            return
        self.env_color = nn.Parameter(
            torch.as_tensor(self.cfg.color, dtype=torch.float32)
        )

    def forward(
        self,
        dirs: Float[Tensor, "B H W Nc"],
        color_spec: Optional[Float[Tensor, "Nc"]] = None,
    ) -> Float[Tensor, "B H W Nc"]:
        """Return the background color broadcast over the leading dims of ``dirs``.

        ``color_spec`` overrides the configured color when given. In training
        mode with ``random_aug`` enabled, the result is replaced by one random
        color per batch element with probability ``random_aug_prob``.
        """
        tint = self.env_color if color_spec is None else color_spec
        color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) * tint

        randomize = (
            self.training
            and self.cfg.random_aug
            and random.random() < self.cfg.random_aug_prob
        )
        if randomize:
            per_batch = torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims).to(dirs)
            # ``color * 0 +`` keeps env_color in the graph; this prevents
            # checking for unused parameters in DDP.
            color = color * 0 + per_batch.expand(*dirs.shape[:-1], -1)
        return color
3D_Stage/lrm/models/camera.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from ..utils.base import BaseModule
7
+ from ..utils.typing import *
8
+
9
+
10
class LinearCameraEmbedder(BaseModule):
    """Embed concatenated camera conditions with a single linear layer."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 0
        out_channels: int = 0
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Build the ``in_channels -> out_channels`` projection."""
        super().configure()
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        """Flatten each configured condition, concatenate them, and project.

        Every name in ``cfg.conditions`` must be passed as a keyword argument.
        Each tensor keeps its first two dims (B, Nv) and has the rest flattened;
        the concatenated width must equal ``cfg.in_channels``.
        """
        flattened = []
        for key in self.cfg.conditions:
            assert key in kwargs
            value = kwargs[key]
            # value in shape (B, Nv, ...)
            flattened.append(value.view(*value.shape[:2], -1))
        merged = torch.cat(flattened, dim=-1)
        assert merged.shape[-1] == self.cfg.in_channels
        return self.linear(merged)
3D_Stage/lrm/models/exporters/__init__.py ADDED
File without changes
3D_Stage/lrm/models/exporters/base.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import lrm
4
+ from ..renderers.base import BaseRenderer
5
+ from ...utils.base import BaseObject
6
+ from ...utils.typing import *
7
+
8
+
9
@dataclass
class ExporterOutput:
    """One exported artifact: its file name, its type tag, and its save parameters."""

    save_name: str
    save_type: str
    params: Dict[str, Any]
14
+
15
+
16
class Exporter(BaseObject):
    """Abstract exporter bound to a renderer; subclasses implement ``__call__``."""

    @dataclass
    class Config(BaseObject.Config):
        save_video: bool = False

    cfg: Config

    def configure(self, renderer: BaseRenderer) -> None:
        """Remember the renderer that exports will query."""
        self.renderer = renderer

    def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
        """Produce the list of export outputs (abstract)."""
        raise NotImplementedError
28
+
29
+
30
class DummyExporter(Exporter):
    """Exporter that exports nothing."""

    def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
        # Arguments are accepted and ignored; there is never any output.
        return []
3D_Stage/lrm/models/exporters/mesh_exporter.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ import tempfile
3
+ import os
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+
9
+ import lrm
10
+ from ..renderers.base import BaseRenderer
11
+ from .base import Exporter, ExporterOutput
12
+ from ..mesh import Mesh
13
+ from ...utils.rasterize import NVDiffRasterizerContext
14
+ from ...utils.typing import *
15
+ from ...utils.misc import time_recorder as tr, time_recorder_enabled
16
+
17
+
18
def uv_padding_cpu(image, hole_mask, padding):
    """Fill the masked texels of ``image`` with OpenCV (Telea) inpainting on the CPU.

    ``image`` and ``hole_mask`` are tensors; both are scaled by 255 and cast to
    uint8 for OpenCV. The result is scaled back by 1/255 and returned with the
    dtype and device of ``image``. ``padding`` is the inpainting radius.
    """
    image_u8 = (image.detach().cpu().numpy() * 255).astype(np.uint8)
    mask_u8 = (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8)
    filled = cv2.inpaint(image_u8, mask_u8, padding, cv2.INPAINT_TELEA) / 255.0
    return torch.from_numpy(filled).to(image)
30
+
31
+
32
def uv_padding_cvc(image, hole_mask, padding):
    """Fill the masked texels of ``image`` using CV-CUDA inpainting.

    ``cvcuda`` is imported lazily, so calling this raises if the package is not
    installed. Inputs are scaled by 255 and cast to uint8; the result is scaled
    back by 1/255 and returned with the dtype and device of ``image``.
    """
    import cvcuda

    image_cvc = cvcuda.as_tensor((image.detach() * 255).to(torch.uint8), "HWC")
    mask_cvc = cvcuda.as_tensor((hole_mask.detach() * 255).to(torch.uint8), "HW")
    filled_cvc = cvcuda.inpaint(image_cvc, mask_cvc, padding)
    filled = torch.as_tensor(filled_cvc.cuda()) / 255.0
    return filled.to(image)
44
+
45
+
46
def uv_padding(image, hole_mask, padding):
    """Inpaint the holes of a UV texture, preferring CV-CUDA with a CPU fallback.

    Args:
        image: Texture tensor to pad.
        hole_mask: Mask tensor marking the texels to fill.
        padding: Inpainting radius.

    Returns:
        The inpainted texture as produced by whichever backend ran.
    """
    try:
        inpaint_image = uv_padding_cvc(image, hole_mask, padding)
    except Exception:
        # Any CV-CUDA failure (missing package, runtime error) falls back to
        # OpenCV. ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt and SystemExit still propagate.
        lrm.info("CVCUDA not available, fallback to CPU UV padding.")
        inpaint_image = uv_padding_cpu(image, hole_mask, padding)
    return inpaint_image
53
+
54
+
55
class MeshExporter(Exporter):
    """Extract a mesh per scene code and package it with UV textures or vertex colors."""

    @dataclass
    class Config(Exporter.Config):
        fmt: str = "obj"  # in ['obj', 'glb']
        visual: str = "uv"  # in ['uv', 'vertex']
        save_name: str = "model"
        save_normal: bool = False
        save_uv: bool = True
        save_texture: bool = True
        texture_size: int = 1024
        texture_format: str = "jpg"
        uv_unwrap_method: str = "xatlas"
        xatlas_chart_options: dict = field(default_factory=dict)
        xatlas_pack_options: dict = field(default_factory=dict)
        smartuv_options: dict = field(default_factory=dict)
        uv_padding_size: int = 2
        subdivide: bool = False
        post_process: bool = False
        post_process_options: dict = field(default_factory=dict)
        context_type: str = "gl"
        output_path: str = "outputs"

    cfg: Config

    def configure(self, renderer: BaseRenderer) -> None:
        """Bind the renderer, create the rasterizer context, and normalize the config."""
        super().configure(renderer)
        self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device)
        if self.cfg.fmt == "obj-mtl":
            # Legacy alias: rewrite it to the equivalent fmt/visual pair.
            lrm.warn(
                "fmt=obj-mtl is deprecated, please use fmt=obj and visual=uv instead."
            )
            self.cfg.fmt = "obj"
            self.cfg.visual = "uv"

        if self.cfg.fmt == "glb":
            assert self.cfg.visual in [
                "vertex",
                "uv-blender",
            ], "GLB format only supports visual=vertex and visual=uv-blender!"

    def get_geometry(self, scene_code: torch.Tensor) -> Mesh:
        """Extract the isosurface mesh for one scene code via the renderer."""
        tr.start("Surface extraction")
        mesh: Mesh = self.renderer.isosurface(scene_code)
        tr.end("Surface extraction")
        return mesh

    def get_texture_maps(
        self, scene_code: torch.Tensor, mesh: Mesh
    ) -> Dict[str, torch.Tensor]:
        """Bake material outputs into UV-space textures.

        The mesh is rasterized in UV space at ``cfg.texture_size``, world
        positions are interpolated per texel, the renderer and material are
        queried at those positions, and texels not covered by any triangle are
        filled by ``uv_padding``.

        Returns:
            A dict with any of the keys ``map_Kd``, ``map_Pm``, ``map_Pr`` and
            ``map_Bump``, depending on what the material exports.
        """
        assert mesh.has_uv
        # clip space transform
        uv_clip = mesh.v_tex * 2.0 - 1.0
        # pad to four component coordinate
        uv_clip4 = torch.cat(
            (
                uv_clip,
                torch.zeros_like(uv_clip[..., 0:1]),
                torch.ones_like(uv_clip[..., 0:1]),
            ),
            dim=-1,
        )
        # rasterize
        rast, _ = self.ctx.rasterize_one(
            uv_clip4,
            mesh.t_tex_idx,
            (self.cfg.texture_size, self.cfg.texture_size),
        )

        # Texels whose 4th rasterizer channel is not positive are treated as holes.
        hole_mask = ~(rast[:, :, 3] > 0)

        # Interpolate world space position
        gb_pos, _ = self.ctx.interpolate_one(
            mesh.v_pos, rast[None, ...], mesh.t_pos_idx
        )
        gb_pos = gb_pos[0]

        # Sample out textures from MLP
        tr.start("Query color")
        geo_out = self.renderer.query(scene_code, points=gb_pos)
        tr.end("Query color")
        mat_out = self.renderer.material.export(points=gb_pos, **geo_out)

        textures = {}
        tr.start("UV padding")
        if "albedo" in mat_out:
            textures["map_Kd"] = uv_padding(
                mat_out["albedo"], hole_mask, self.cfg.uv_padding_size
            )
        else:
            lrm.warn(
                "save_texture is True but no albedo texture found, using default white texture"
            )
        if "metallic" in mat_out:
            textures["map_Pm"] = uv_padding(
                mat_out["metallic"], hole_mask, self.cfg.uv_padding_size
            )
        if "roughness" in mat_out:
            textures["map_Pr"] = uv_padding(
                mat_out["roughness"], hole_mask, self.cfg.uv_padding_size
            )
        if "bump" in mat_out:
            textures["map_Bump"] = uv_padding(
                mat_out["bump"], hole_mask, self.cfg.uv_padding_size
            )
        tr.end("UV padding")
        return textures

    def __call__(self, names, scene_codes) -> List[ExporterOutput]:
        """Export one model per ``(name, scene_code)`` pair.

        Raises:
            ValueError: If ``cfg.visual`` is not a supported mode.
        """
        outputs = []
        for name, scene_code in zip(names, scene_codes):
            mesh = self.get_geometry(scene_code)
            if self.cfg.post_process:
                tr.start("Mesh post-processing")
                mesh = mesh.post_process(self.cfg.post_process_options)
                tr.end("Mesh post-processing")
            if self.cfg.visual == "uv":
                output = self.export_model_with_mtl(
                    name, self.cfg.fmt, scene_code, mesh
                )
            elif self.cfg.visual == "vertex":
                output = self.export_model(name, self.cfg.fmt, scene_code, mesh)
            elif self.cfg.visual == "uv-blender":
                # NOTE(review): export_model_blender is not defined in this
                # class as visible here; unless a subclass or mixin provides
                # it, this branch raises AttributeError. Confirm.
                output = self.export_model_blender(name, self.cfg.fmt, scene_code, mesh)
            else:
                raise ValueError(f"Unsupported visual format: {self.cfg.visual}")
            outputs.append(output)
        return outputs

    def export_model_with_mtl(
        self, name: str, fmt: str, scene_code: torch.Tensor, mesh: Mesh
    ) -> ExporterOutput:
        """Export a mesh with material maps baked into UV textures.

        When ``cfg.save_texture`` is set, the UV coordinates and UV triangle
        indices are also written to ``{cfg.output_path}/tex_info.npz``.
        """
        params = {
            "mesh": mesh,
            "save_mat": True,
            "save_normal": self.cfg.save_normal,
            "save_uv": self.cfg.save_uv,
            "save_vertex_color": False,
            "map_Kd": None,  # Base Color
            "map_Ks": None,  # Specular
            "map_Bump": None,  # Normal
            # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
            "map_Pm": None,  # Metallic
            "map_Pr": None,  # Roughness
            "map_format": self.cfg.texture_format,
        }

        if self.cfg.save_uv:
            mesh.unwrap_uv(
                self.cfg.uv_unwrap_method,
                self.cfg.xatlas_chart_options,
                self.cfg.xatlas_pack_options,
                self.cfg.smartuv_options,
            )

        if self.cfg.save_texture:
            lrm.info("Exporting textures ...")
            assert self.cfg.save_uv, "save_uv must be True when save_texture is True"

            with time_recorder_enabled():
                textures = self.get_texture_maps(scene_code, mesh)
            params.update(textures)
            os.makedirs(self.cfg.output_path, exist_ok=True)
            np.savez(f"{self.cfg.output_path}/tex_info.npz", v_tex=mesh.v_tex.cpu().numpy(), t_tex_idx=mesh.t_tex_idx.cpu().numpy())
        return ExporterOutput(
            save_name=f"{self.cfg.save_name}-{name}.{fmt}", save_type=fmt, params=params
        )

    def export_model(
        self, name: str, fmt: str, scene_code, mesh: Mesh
    ) -> ExporterOutput:
        """Export a mesh without material files, optionally with per-vertex colors.

        When ``cfg.save_texture`` is set and the material exports ``albedo``,
        the albedo queried at the vertex positions becomes the vertex color.
        """
        params = {
            "mesh": mesh,
            "save_mat": False,
            "save_normal": self.cfg.save_normal,
            "save_uv": self.cfg.save_uv,
            "save_vertex_color": False,
            "map_Kd": None,  # Base Color
            "map_Ks": None,  # Specular
            "map_Bump": None,  # Normal
            # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
            "map_Pm": None,  # Metallic
            "map_Pr": None,  # Roughness
            "map_format": self.cfg.texture_format,
        }

        if self.cfg.save_uv:
            mesh.unwrap_uv(
                self.cfg.uv_unwrap_method,
                self.cfg.xatlas_chart_options,
                self.cfg.xatlas_pack_options,
                self.cfg.smartuv_options,
            )

        if self.cfg.save_texture:
            lrm.info("Exporting textures ...")
            geo_out = self.renderer.query(scene_code, points=mesh.v_pos)
            mat_out = self.renderer.material.export(points=mesh.v_pos, **geo_out)

            if "albedo" in mat_out:
                mesh.set_vertex_color(mat_out["albedo"])
                params["save_vertex_color"] = True
            else:
                lrm.warn(
                    "save_texture is True but no albedo texture found, not saving vertex color"
                )

        return ExporterOutput(
            save_name=f"{self.cfg.save_name}-{name}.{fmt}", save_type=fmt, params=params
        )
3D_Stage/lrm/models/isosurface.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import lrm
7
+ from ..models.mesh import Mesh
8
+ from ..utils.typing import *
9
+ from ..utils.ops import scale_tensor
10
+
11
+
12
class IsosurfaceHelper(nn.Module):
    """Base class for isosurface extractors operating on a fixed grid."""

    # Coordinate range spanned by the grid vertices.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Positions of the grid vertices (abstract)."""
        raise NotImplementedError
18
+
19
+
20
class MarchingCubeCPUHelper(IsosurfaceHelper):
    """Isosurface extraction with the CPU marching-cubes routine from ``mcubes``."""

    def __init__(self, resolution: int) -> None:
        super().__init__()
        self.resolution = resolution
        # Imported lazily so the dependency is only needed when this helper is used.
        import mcubes

        self.mc_func: Callable = mcubes.marching_cubes
        self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None
        # Empty buffer whose only purpose is to track the module's device.
        self._dummy: Float[Tensor, "..."]
        self.register_buffer(
            "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
        )

    @property
    def grid_vertices(self) -> Float[Tensor, "N3 3"]:
        """Regular ``resolution**3`` lattice over ``points_range``, built lazily."""
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            x, y, z = (
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
            )
            x, y, z = torch.meshgrid(x, y, z, indexing="ij")
            verts = torch.cat(
                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
            ).reshape(-1, 3)
            self._grid_vertices = verts
        return self._grid_vertices

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the zero level set of ``level`` as a mesh.

        Args:
            level: Level values on the grid vertices; negated before extraction.
            deformation: Not supported; a warning is logged and it is ignored.

        Returns:
            A ``Mesh`` whose vertex positions are divided by ``resolution - 1``
            and placed on this module's device.
        """
        if deformation is not None:
            lrm.warn(
                f"{self.__class__.__name__} does not support deformation. Ignoring."
            )
        level = -level.view(self.resolution, self.resolution, self.resolution)
        # Diagnostics go through the logger (lazy %-formatting) instead of an
        # unconditional print on every call.
        lrm.debug(
            "marching cubes level grid: shape=%s min=%s max=%s",
            level.shape,
            level.min(),
            level.max(),
        )
        v_pos, t_pos_idx = self.mc_func(
            level.detach().cpu().numpy(), 0.0
        )  # transform to numpy
        v_pos, t_pos_idx = (
            torch.from_numpy(v_pos).float().to(self._dummy.device),
            torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device),
        )  # transform back to torch tensor on the module's device
        v_pos = v_pos / (self.resolution - 1.0)
        return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
70
+
71
+
72
def get_center_boundary_index(verts):
    """Locate the most central vertex and the vertices touching the bounding extremes.

    Assumes ``verts`` lies in the range [-1.0, 1.0].

    Returns:
        A pair ``(center_idx, boundary_idx)``: the index (as a 1-element tensor)
        of the vertex minimizing the L2 norm of its squared coordinates, and the
        indices of all vertices with at least one coordinate equal to the global
        minimum or maximum of ``verts``.
    """
    radius = torch.linalg.norm(verts**2, ord=2, dim=-1, keepdim=False)
    center_idx = torch.argmin(radius)
    # center_idx = torch.where(radius < 0.1)[0]
    at_min = verts == verts.min()
    at_max = verts == verts.max()
    on_boundary = torch.bitwise_or(at_min, at_max)
    # Count the boundary coordinates per vertex; any non-zero count qualifies.
    hits = torch.sum(on_boundary.float(), dim=-1)
    boundary_idx = torch.nonzero(hits)
    return center_idx.unsqueeze(0), boundary_idx.squeeze(dim=-1)
83
+
84
+
85
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Isosurface extraction by marching tetrahedra on a precomputed tet grid."""

    def __init__(self, resolution: int, tets_path: str):
        """Load the tetrahedral grid and register the lookup tables.

        Args:
            resolution: Grid resolution, used to scale vertex deformations.
            tets_path: Path of an ``.npz`` archive with ``vertices`` and
                ``indices`` arrays.
        """
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # For each of the 16 inside/outside configurations of a tet's four
        # vertices: the edge slots forming up to two triangles (-1 = unused).
        self.triangle_table: Float[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of triangles emitted for each of the 16 configurations.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # Local vertex pairs of the six edges of a tetrahedron, flattened.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        # Read both arrays inside a context manager so the archive's file
        # handle is closed instead of lingering for the module's lifetime.
        with np.load(self.tets_path) as tets:
            grid_vertices = torch.from_numpy(tets["vertices"]).float()
            tet_indices = torch.from_numpy(tets["indices"]).long()

        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer(
            "_grid_vertices",
            grid_vertices,
            persistent=False,
        )
        self.indices: Integer[Tensor, "..."]
        self.register_buffer("indices", tet_indices, persistent=False)

        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.center_indices, self.boundary_indices = get_center_boundary_index(
            scale_tensor(self.grid_vertices, self.points_range, (-1.0, 1.0))
        )

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with tanh and scale them to about half a tet."""
        return (
            (self.points_range[1] - self.points_range[0])
            / (self.resolution)  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed grid vertex positions."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique (sorted) vertex-index pairs of all tet edges, computed once."""
        if self._all_edges is None:
            # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
            edges = self.base_tet_edges.to(self.indices.device)
            _all_edges = self.indices[:, edges].reshape(-1, 2)
            _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
            _all_edges = torch.unique(_all_edges_sorted, dim=0)
            self._all_edges = _all_edges
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each edge so its smaller vertex index comes first."""
        with torch.no_grad():
            # order == 1 where the pair is descending and must be swapped.
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra.

        Args:
            pos_nx3: Grid vertex positions.
            sdf_n: Level value per grid vertex; positive counts as occupied.
            tet_fx4: Vertex indices of each tetrahedron.

        Returns:
            ``(verts, faces)``: surface vertices interpolated along the edges
            that cross the level set, and triangle indices into ``verts``.
        """
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tets with mixed occupancy intersect the surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # Edges with exactly one occupied endpoint cross the surface.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]
        # The interpolation below stays outside no_grad so that gradients can
        # reach the positions and level values.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # Encode each valid tet's occupancy pattern as a 4-bit table index.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract a mesh from per-vertex level values.

        Args:
            level: Level value for each grid vertex.
            deformation: Optional per-vertex offsets; normalized with
                ``normalize_grid_deformation`` and added to the grid.

        Returns:
            A ``Mesh`` that also carries the (possibly deformed) grid vertices,
            the tet edges, the level values and the raw deformation as extras.
        """
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh
3D_Stage/lrm/models/lpips.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+ import lpips
4
+
5
+ from ..utils.ops import scale_tensor
6
+ from ..utils.misc import get_device
7
+
8
+
9
class LPIPS:
    """Callable wrapper around a frozen ``lpips.LPIPS(net="vgg")`` model."""

    def __init__(self):
        net = lpips.LPIPS(net="vgg").to(get_device())
        net.eval()
        # The metric network is never trained: disable gradients on all of
        # its parameters.
        net.requires_grad_(False)
        self.model = net
        # Value range the inputs are rescaled to before calling the model.
        self.model_input_range = (-1, 1)

    def __call__(self, x1, x2, return_layers=False, input_range=(0, 1)):
        """Evaluate the LPIPS model on two inputs.

        Args:
            x1, x2: inputs whose values lie in ``input_range``.
            return_layers: forwarded to the model as ``retPerLayer``.
            input_range: value range of ``x1`` and ``x2``; both are rescaled
                from it to ``self.model_input_range`` with ``scale_tensor``.

        Returns:
            Whatever the underlying model returns for the rescaled inputs.
        """
        target_range = self.model_input_range
        a = scale_tensor(x1, input_range, target_range)
        b = scale_tensor(x2, input_range, target_range)
        # Inputs were already rescaled above, so the model's own
        # normalization is switched off.
        return self.model(a, b, retPerLayer=return_layers, normalize=False)
3D_Stage/lrm/models/materials/__init__.py ADDED
File without changes
3D_Stage/lrm/models/materials/base.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import lrm
9
+ from ...utils.base import BaseModule
10
+ from ...utils.typing import *
11
+
12
+
13
class BaseMaterial(BaseModule):
    """Base class for materials; subclasses implement ``forward``."""

    @dataclass
    class Config(BaseModule.Config):
        """Adds no options on top of ``BaseModule.Config``."""

    cfg: Config
    # Class-level flags, both off by default; ``NoMaterial`` in this package
    # overrides ``requires_normal`` from its config.
    requires_normal: bool = False
    requires_tangent: bool = False

    def configure(self):
        """Set-up hook; the base material has nothing to configure."""
        return None

    def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]:
        """Must be overridden by subclasses."""
        raise NotImplementedError()

    def export(self, *args, **kwargs) -> Dict[str, Any]:
        """Return exportable attributes; the base material exports nothing."""
        return dict()
3D_Stage/lrm/models/materials/no_material.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import lrm
9
+ from .base import BaseMaterial
10
+ from ..networks import get_encoding, get_mlp
11
+ from ...utils.ops import dot, get_activation
12
+ from ...utils.typing import *
13
+
14
+
15
+ class NoMaterial(BaseMaterial):
16
+ @dataclass
17
+ class Config(BaseMaterial.Config):
18
+ n_output_dims: int = 3
19
+ color_activation: str = "sigmoid"
20
+ input_feature_dims: Optional[int] = None
21
+ mlp_network_config: Optional[dict] = None
22
+ requires_normal: bool = False
23
+
24
+ cfg: Config
25
+
26
+ def configure(self) -> None:
27
+ self.use_network = False
28
+ if (
29
+ self.cfg.input_feature_dims is not None
30
+ and self.cfg.mlp_network_config is not None
31
+ ):
32
+ self.network = get_mlp(
33
+ self.cfg.input_feature_dims,
34
+ self.cfg.n_output_dims,
35
+ self.cfg.mlp_network_config,
36
+ )
37
+ self.use_network = True
38
+ self.requires_normal = self.cfg.requires_normal
39
+
40
+ def forward(
41
+ self, features: Float[Tensor, "B ... Nf"], **kwargs
42
+ ) -> Float[Tensor, "B ... Nc"]:
43
+ if not self.use_network:
44
+ assert (
45
+ features.shape[-1] == self.cfg.n_output_dims
46
+ ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input."
47
+ color = get_activation(self.cfg.color_activation)(features)
48
+ else:
49
+ color = self.network(features.view(-1, features.shape[-1])).view(
50
+ *features.shape[:-1], self.cfg.n_output_dims
51
+ )
52
+ color = get_activation(self.cfg.color_activation)(color)
53
+ return color
54
+
55
+ def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]:
56
+ color = self(features, **kwargs).clamp(0, 1)
57
+ assert color.shape[-1] >= 3, "Output color must have at least 3 channels"
58
+ if color.shape[-1] > 3:
59
+ lrm.warn("Output color has >3 channels, treating the first 3 as RGB")
60
+ return {"albedo": color[..., :3]}
3D_Stage/lrm/models/mesh.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ import lrm
9
+ from ..utils.ops import dot
10
+ from ..utils.typing import *
11
+ from ..utils.misc import time_recorder as tr, time_recorder_enabled
12
+
13
+
14
+ class Mesh:
15
+ def __init__(
16
+ self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
17
+ ) -> None:
18
+ self.v_pos: Float[Tensor, "Nv 3"] = v_pos
19
+ self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
20
+ self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
21
+ self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
22
+ self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
23
+ self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None
24
+ self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
25
+ self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
26
+ self.extras: Dict[str, Any] = {}
27
+ for k, v in kwargs.items():
28
+ self.add_extra(k, v)
29
+
30
+ def add_extra(self, k, v) -> None:
31
+ self.extras[k] = v
32
+
33
+ def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh:
34
+ if self.requires_grad:
35
+ lrm.debug("Mesh is differentiable, not removing outliers")
36
+ return self
37
+
38
+ # use trimesh to first split the mesh into connected components
39
+ # then remove the components with less than n_face_threshold faces
40
+ import trimesh
41
+
42
+ # construct a trimesh object
43
+ mesh = trimesh.Trimesh(
44
+ vertices=self.v_pos.detach().cpu().numpy(),
45
+ faces=self.t_pos_idx.detach().cpu().numpy(),
46
+ )
47
+
48
+ # split the mesh into connected components
49
+ components = mesh.split(only_watertight=False)
50
+ # log the number of faces in each component
51
+ lrm.debug(
52
+ "Mesh has {} components, with faces: {}".format(
53
+ len(components), [c.faces.shape[0] for c in components]
54
+ )
55
+ )
56
+
57
+ n_faces_threshold: int
58
+ if isinstance(outlier_n_faces_threshold, float):
59
+ # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
60
+ n_faces_threshold = int(
61
+ max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
62
+ )
63
+ else:
64
+ # set the threshold directly to outlier_n_faces_threshold
65
+ n_faces_threshold = outlier_n_faces_threshold
66
+
67
+ # log the threshold
68
+ lrm.debug(
69
+ "Removing components with less than {} faces".format(n_faces_threshold)
70
+ )
71
+
72
+ # remove the components with less than n_face_threshold faces
73
+ components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
74
+
75
+ # log the number of faces in each component after removing outliers
76
+ lrm.debug(
77
+ "Mesh has {} components after removing outliers, with faces: {}".format(
78
+ len(components), [c.faces.shape[0] for c in components]
79
+ )
80
+ )
81
+ # merge the components
82
+ mesh = trimesh.util.concatenate(components)
83
+
84
+ # convert back to our mesh format
85
+ v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
86
+ t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
87
+
88
+ clean_mesh = Mesh(v_pos, t_pos_idx)
89
+ # keep the extras unchanged
90
+
91
+ if len(self.extras) > 0:
92
+ clean_mesh.extras = self.extras
93
+ lrm.debug(
94
+ f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
95
+ )
96
+ return clean_mesh
97
+
98
+ def subdivide(self):
99
+ if self.requires_grad:
100
+ lrm.debug("Mesh is differentiable, not performing subdivision")
101
+ return self
102
+
103
+ import trimesh
104
+
105
+ mesh = trimesh.Trimesh(
106
+ vertices=self.v_pos.detach().cpu().numpy(),
107
+ faces=self.t_pos_idx.detach().cpu().numpy(),
108
+ )
109
+
110
+ mesh.subdivide_loop()
111
+
112
+ v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
113
+ t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
114
+
115
+ subdivided_mesh = Mesh(v_pos, t_pos_idx)
116
+
117
+ if len(self.extras) > 0:
118
+ subdivided_mesh.extras = self.extras
119
+ lrm.debug(
120
+ f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
121
+ )
122
+
123
+ return subdivided_mesh
124
+
125
+ def post_process(self, options):
126
+ if self.requires_grad:
127
+ lrm.debug("Mesh is differentiable, not performing post processing")
128
+ return self
129
+
130
+ from extern.mesh_process.MeshProcess import process_mesh
131
+
132
+ v_pos, t_pos_idx = process_mesh(
133
+ vertices=self.v_pos.detach().cpu().numpy(),
134
+ faces=self.t_pos_idx.detach().cpu().numpy(),
135
+ **options,
136
+ )
137
+
138
+ v_pos = torch.from_numpy(v_pos).to(self.v_pos).contiguous()
139
+ t_pos_idx = torch.from_numpy(t_pos_idx).to(self.t_pos_idx).contiguous()
140
+
141
+ processed_mesh = Mesh(v_pos, t_pos_idx)
142
+
143
+ if len(self.extras) > 0:
144
+ processed_mesh.extras = self.extras
145
+ lrm.debug(
146
+ f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
147
+ )
148
+
149
+ return processed_mesh
150
+
151
+ @property
152
+ def requires_grad(self):
153
+ return self.v_pos.requires_grad
154
+
155
+ @property
156
+ def v_nrm(self):
157
+ if self._v_nrm is None:
158
+ self._v_nrm = self._compute_vertex_normal()
159
+ return self._v_nrm
160
+
161
+ @property
162
+ def v_tng(self):
163
+ if self._v_tng is None:
164
+ self._v_tng = self._compute_vertex_tangent()
165
+ return self._v_tng
166
+
167
+ @property
168
+ def v_tex(self):
169
+ if self._v_tex is None:
170
+ self._v_tex, self._t_tex_idx = self._unwrap_uv()
171
+ return self._v_tex
172
+
173
+ @property
174
+ def t_tex_idx(self):
175
+ if self._t_tex_idx is None:
176
+ self._v_tex, self._t_tex_idx = self._unwrap_uv()
177
+ return self._t_tex_idx
178
+
179
+ @property
180
+ def v_rgb(self):
181
+ return self._v_rgb
182
+
183
+ @property
184
+ def edges(self):
185
+ if self._edges is None:
186
+ self._edges = self._compute_edges()
187
+ return self._edges
188
+
189
+ def _compute_vertex_normal(self):
190
+ i0 = self.t_pos_idx[:, 0]
191
+ i1 = self.t_pos_idx[:, 1]
192
+ i2 = self.t_pos_idx[:, 2]
193
+
194
+ v0 = self.v_pos[i0, :]
195
+ v1 = self.v_pos[i1, :]
196
+ v2 = self.v_pos[i2, :]
197
+
198
+ face_normals = torch.cross(v1 - v0, v2 - v0)
199
+
200
+ # Splat face normals to vertices
201
+ v_nrm = torch.zeros_like(self.v_pos)
202
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
203
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
204
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
205
+
206
+ # Normalize, replace zero (degenerated) normals with some default value
207
+ v_nrm = torch.where(
208
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
209
+ )
210
+ v_nrm = F.normalize(v_nrm, dim=1)
211
+
212
+ if torch.is_anomaly_enabled():
213
+ assert torch.all(torch.isfinite(v_nrm))
214
+
215
+ return v_nrm
216
+
217
+ def _compute_vertex_tangent(self):
218
+ vn_idx = [None] * 3
219
+ pos = [None] * 3
220
+ tex = [None] * 3
221
+ for i in range(0, 3):
222
+ pos[i] = self.v_pos[self.t_pos_idx[:, i]]
223
+ tex[i] = self.v_tex[self.t_tex_idx[:, i]]
224
+ # t_nrm_idx is always the same as t_pos_idx
225
+ vn_idx[i] = self.t_pos_idx[:, i]
226
+
227
+ tangents = torch.zeros_like(self.v_nrm)
228
+ tansum = torch.zeros_like(self.v_nrm)
229
+
230
+ # Compute tangent space for each triangle
231
+ uve1 = tex[1] - tex[0]
232
+ uve2 = tex[2] - tex[0]
233
+ pe1 = pos[1] - pos[0]
234
+ pe2 = pos[2] - pos[0]
235
+
236
+ nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
237
+ denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]
238
+
239
+ # Avoid division by zero for degenerated texture coordinates
240
+ tang = nom / torch.where(
241
+ denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
242
+ )
243
+
244
+ # Update all 3 vertices
245
+ for i in range(0, 3):
246
+ idx = vn_idx[i][:, None].repeat(1, 3)
247
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
248
+ tansum.scatter_add_(
249
+ 0, idx, torch.ones_like(tang)
250
+ ) # tansum[n_i] = tansum[n_i] + 1
251
+ tangents = tangents / tansum
252
+
253
+ # Normalize and make sure tangent is perpendicular to normal
254
+ tangents = F.normalize(tangents, dim=1)
255
+ tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
256
+
257
+ if torch.is_anomaly_enabled():
258
+ assert torch.all(torch.isfinite(tangents))
259
+
260
+ return tangents
261
+
262
+ def _unwrap_uv_open3d(
263
+ self
264
+ ):
265
+ import open3d as o3d
266
+ mesh = o3d.t.geometry.TriangleMesh()
267
+ mesh.vertex.positions = o3d.core.Tensor(self.v_pos.detach().cpu().numpy())
268
+ mesh.triangle.indices = o3d.core.Tensor(self.t_pos_idx.cpu().numpy())
269
+ mesh.compute_uvatlas(size=1024)
270
+ texture_uvs = torch.from_numpy(mesh.triangle.texture_uvs.numpy()).reshape(-1, 2).cuda()
271
+ indices = torch.arange(self.t_pos_idx.shape[0] * 3).reshape(-1, 3).to(torch.int64).cuda()
272
+ # Add a wood texture and visualize
273
+ return texture_uvs, indices
274
+
275
+ def _unwrap_uv_xatlas(
276
+ self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
277
+ ):
278
+ lrm.info("Using xatlas to perform UV unwrapping, may take a while ...")
279
+
280
+ import xatlas
281
+
282
+ atlas = xatlas.Atlas()
283
+ atlas.add_mesh(
284
+ self.v_pos.detach().cpu().numpy(),
285
+ self.t_pos_idx.cpu().numpy(),
286
+ )
287
+ co = xatlas.ChartOptions()
288
+ po = xatlas.PackOptions()
289
+ for k, v in xatlas_chart_options.items():
290
+ setattr(co, k, v)
291
+ for k, v in xatlas_pack_options.items():
292
+ setattr(po, k, v)
293
+ atlas.generate(co, po)
294
+ vmapping, indices, uvs = atlas.get_mesh(0)
295
+ vmapping = (
296
+ torch.from_numpy(
297
+ vmapping.astype(np.uint64, casting="same_kind").view(np.int64)
298
+ )
299
+ .to(self.v_pos.device)
300
+ .long()
301
+ )
302
+ uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
303
+ indices = (
304
+ torch.from_numpy(
305
+ indices.astype(np.uint64, casting="same_kind").view(np.int64)
306
+ )
307
+ .to(self.v_pos.device)
308
+ .long()
309
+ )
310
+ return uvs, indices
311
+
312
+ def _unwrap_uv_smartuv(self, options: dict = {}):
313
+ from extern.mesh_process.MeshProcess import (
314
+ mesh_to_bpy,
315
+ get_uv_from_bpy,
316
+ bpy_context,
317
+ bpy_export,
318
+ )
319
+ from lrm.utils.misc import time_recorder as tr
320
+
321
+ v_pos, t_pos_idx = self.v_pos.cpu().numpy(), self.t_pos_idx.cpu().numpy()
322
+ with bpy_context():
323
+ mesh_bpy = mesh_to_bpy("_", v_pos, t_pos_idx)
324
+ v_tex = get_uv_from_bpy(mesh_bpy, **options).astype(np.float32)
325
+
326
+ assert v_tex.shape[0] == self.t_pos_idx.shape[0] * 3
327
+
328
+ t_tex_idx = torch.arange(
329
+ self.t_pos_idx.shape[0] * 3, device=self.t_pos_idx.device, dtype=torch.long
330
+ ).reshape(-1, 3)
331
+
332
+ """
333
+ # super efficient de-duplication
334
+ v_tex_u_uint32 = v_tex[..., 0].view(np.uint32)
335
+ v_tex_v_uint32 = v_tex[..., 1].view(np.uint32)
336
+ v_hashed = (v_tex_u_uint32.astype(np.uint64) << 32) | v_tex_v_uint32
337
+ v_hashed = torch.from_numpy(v_hashed.view(np.int64)).to(self.v_pos.device)
338
+
339
+ v_tex = torch.from_numpy(v_tex).to(
340
+ device=self.v_pos.device, dtype=torch.float32
341
+ )
342
+ t_pos_idx_f3 = torch.arange(
343
+ self.t_pos_idx.shape[0] * 3, device=self.t_pos_idx.device, dtype=torch.long
344
+ ).reshape(-1, 3)
345
+ v_pos_f3 = self.v_pos[self.t_pos_idx].reshape(-1, 3)
346
+
347
+ # super efficient de-duplication
348
+ v_hashed_dedup, inverse_indices = torch.unique(v_hashed, return_inverse=True)
349
+ dedup_size, full_size = v_hashed_dedup.shape[0], inverse_indices.shape[0]
350
+ indices = torch.scatter_reduce(
351
+ torch.full(
352
+ [dedup_size],
353
+ fill_value=full_size,
354
+ device=inverse_indices.device,
355
+ dtype=torch.long,
356
+ ),
357
+ index=inverse_indices,
358
+ src=torch.arange(
359
+ full_size, device=inverse_indices.device, dtype=torch.int64
360
+ ),
361
+ dim=0,
362
+ reduce="amin",
363
+ )
364
+ v_tex = v_tex[indices]
365
+ t_tex_idx = inverse_indices.reshape(-1, 3)
366
+
367
+ v_pos = v_pos_f3[indices]
368
+ t_pos_idx = inverse_indices[t_pos_idx_f3]
369
+ """
370
+
371
+ return self.v_pos, self.t_pos_idx, v_tex, t_tex_idx
372
+
373
+ def unwrap_uv(
374
+ self,
375
+ method: str,
376
+ xatlas_chart_options: dict = {},
377
+ xatlas_pack_options: dict = {},
378
+ smartuv_options: dict = {},
379
+ ):
380
+ if method == "xatlas":
381
+ with time_recorder_enabled():
382
+ tr.start("UV unwrapping xatlas")
383
+ self._v_tex, self._t_tex_idx = self._unwrap_uv_xatlas(
384
+ xatlas_chart_options, xatlas_pack_options
385
+ )
386
+ tr.end("UV unwrapping xatlas")
387
+ elif method == "open3d":
388
+ with time_recorder_enabled():
389
+ tr.start("UV unwrapping o3d")
390
+ self._v_tex, self._t_tex_idx = self._unwrap_uv_open3d()
391
+ tr.end("UV unwrapping o3d")
392
+ elif method == "smartuv":
393
+ with time_recorder_enabled():
394
+ tr.start("UV unwrapping smartuv")
395
+ (
396
+ self.v_pos,
397
+ self.t_pos_idx,
398
+ self._v_tex,
399
+ self._t_tex_idx,
400
+ ) = self._unwrap_uv_smartuv(smartuv_options)
401
+ tr.end("UV unwrapping smartuv")
402
+ else:
403
+ raise NotImplementedError
404
+
405
+ def set_vertex_color(self, v_rgb):
406
+ assert v_rgb.shape[0] == self.v_pos.shape[0]
407
+ self._v_rgb = v_rgb
408
+
409
+ def set_uv(self, v_tex, t_tex_idx):
410
+ self._v_tex = v_tex
411
+ self._t_tex_idx = t_tex_idx
412
+
413
+ @property
414
+ def has_uv(self):
415
+ return self._v_tex is not None and self._t_tex_idx is not None
416
+
417
+ def _compute_edges(self):
418
+ # Compute edges
419
+ edges = torch.cat(
420
+ [
421
+ self.t_pos_idx[:, [0, 1]],
422
+ self.t_pos_idx[:, [1, 2]],
423
+ self.t_pos_idx[:, [2, 0]],
424
+ ],
425
+ dim=0,
426
+ )
427
+ edges = edges.sort()[0]
428
+ edges = torch.unique(edges, dim=0)
429
+ return edges
430
+
431
+ def normal_consistency(self) -> Float[Tensor, ""]:
432
+ edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
433
+ nc = (
434
+ 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
435
+ ).mean()
436
+ return nc
437
+
438
+ def _laplacian_uniform(self):
439
+ # from stable-dreamfusion
440
+ # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
441
+ verts, faces = self.v_pos, self.t_pos_idx
442
+
443
+ V = verts.shape[0]
444
+ F = faces.shape[0]
445
+
446
+ # Neighbor indices
447
+ ii = faces[:, [1, 2, 0]].flatten()
448
+ jj = faces[:, [2, 0, 1]].flatten()
449
+ adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
450
+ dim=1
451
+ )
452
+ adj_values = torch.ones(adj.shape[1]).to(verts)
453
+
454
+ # Diagonal indices
455
+ diag_idx = adj[0]
456
+
457
+ # Build the sparse matrix
458
+ idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
459
+ values = torch.cat((-adj_values, adj_values))
460
+
461
+ # The coalesce operation sums the duplicate indices, resulting in the
462
+ # correct diagonal
463
+ return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
464
+
465
+ def laplacian(self) -> Float[Tensor, ""]:
466
+ with torch.no_grad():
467
+ L = self._laplacian_uniform()
468
+ loss = L.mm(self.v_pos)
469
+ loss = loss.norm(dim=1)
470
+ loss = loss.mean()
471
+ return loss
3D_Stage/lrm/models/networks.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from copy import deepcopy
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from ..utils.base import BaseModule
9
+ from ..utils.ops import get_activation
10
+ from ..utils.typing import *
11
+
12
+
13
class TriplaneUpsampleNetwork(BaseModule):
    """Upsamples each of the three planes with one shared transposed convolution."""

    @dataclass
    class Config(BaseModule.Config):
        # Channel counts of the input and output planes.
        in_channels: int = 1024
        out_channels: int = 80

    cfg: Config

    def configure(self) -> None:
        """Create the kernel-2, stride-2 transposed convolution."""
        super().configure()
        cfg = self.cfg
        self.upsample = nn.ConvTranspose2d(
            in_channels=cfg.in_channels,
            out_channels=cfg.out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        """Apply the upsampling convolution to every plane.

        The plane axis is folded into the batch axis for the convolution and
        unfolded again afterwards.
        """
        planes_as_batch = rearrange(
            triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3
        )
        upsampled = self.upsample(planes_as_batch)
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
38
+
39
+
40
class MLP(nn.Module):
    """Fully connected network: input layer, hidden layers, output layer.

    The stack is ``Linear(dim_in, n_neurons)`` + activation, followed by
    ``n_hidden_layers - 1`` blocks of ``Linear(n_neurons, n_neurons)`` +
    activation, and a final ``Linear(n_neurons, dim_out)``. The result is
    passed through the activation resolved by ``get_activation``.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
        weight_init: Optional[str] = "kaiming_uniform",
        bias_init: Optional[str] = None,
    ):
        super().__init__()
        # Options shared by every linear layer of the stack.
        linear_opts = dict(bias=bias, weight_init=weight_init, bias_init=bias_init)

        modules = [
            self.make_linear(
                dim_in, n_neurons, is_first=True, is_last=False, **linear_opts
            ),
            self.make_activation(activation),
        ]
        for _ in range(n_hidden_layers - 1):
            modules.append(
                self.make_linear(
                    n_neurons, n_neurons, is_first=False, is_last=False, **linear_opts
                )
            )
            modules.append(self.make_activation(activation))
        modules.append(
            self.make_linear(
                n_neurons, dim_out, is_first=False, is_last=True, **linear_opts
            )
        )

        self.layers = nn.Sequential(*modules)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        """Run the layer stack, then the output activation."""
        return self.output_activation(self.layers(x))

    def make_linear(
        self,
        dim_in,
        dim_out,
        is_first,
        is_last,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        """Create one ``nn.Linear`` and apply the requested initialization.

        ``weight_init`` may be ``None`` (keep the default) or
        ``"kaiming_uniform"``; ``bias_init`` may be ``None`` or ``"zero"`` and
        is only considered when ``bias`` is truthy. ``is_first`` and
        ``is_last`` are accepted but not used here.

        Raises:
            NotImplementedError: for any other initializer name.
        """
        linear = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(linear.weight, nonlinearity="relu")
        elif weight_init is not None:
            raise NotImplementedError

        if bias and bias_init is not None:
            if bias_init != "zero":
                raise NotImplementedError
            torch.nn.init.zeros_(linear.bias)

        return linear

    def make_activation(self, activation):
        """Return an in-place ReLU or SiLU module for the given name.

        Raises:
            NotImplementedError: for any other name.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError
134
+
135
+
136
@dataclass
class HeadSpec:
    """Description of one output head of ``MultiHeadMLP``."""

    # Key of the head in the module dict and in the output dictionary.
    name: str
    # Output width of the head's final linear layer.
    out_channels: int
    # Number of (linear + activation) blocks placed before the final linear layer.
    n_hidden_layers: int
    # Name passed to ``get_activation`` and applied to the head's output.
    output_activation: Optional[str] = None
142
+
143
+
144
class MultiHeadMLP(BaseModule):
    """Shared MLP trunk followed by several named output heads."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 0
        n_neurons: int = 0
        # The trunk always has one input block; this adds
        # ``n_hidden_layers_share - 1`` further blocks.
        n_hidden_layers_share: int = 0
        heads: List[HeadSpec] = field(default_factory=lambda: [])
        activation: str = "relu"
        bias: bool = True
        weight_init: Optional[str] = "kaiming_uniform"
        bias_init: Optional[str] = None
        # None, "deferred" or "checkpointing"; selects how the trunk is run.
        chunk_mode: Optional[str] = None
        chunk_size: int = -1

    cfg: Config

    def configure(self) -> None:
        """Build the shared trunk and one ``nn.Sequential`` per head."""
        super().configure()
        cfg = self.cfg

        def linear(dim_in, dim_out):
            return self.make_linear(
                dim_in,
                dim_out,
                bias=cfg.bias,
                weight_init=cfg.weight_init,
                bias_init=cfg.bias_init,
            )

        def hidden_block(dim_in):
            # One linear layer into n_neurons followed by the activation.
            return [linear(dim_in, cfg.n_neurons), self.make_activation(cfg.activation)]

        trunk = hidden_block(cfg.in_channels)
        for _ in range(cfg.n_hidden_layers_share - 1):
            trunk.extend(hidden_block(cfg.n_neurons))
        self.shared_layers = nn.Sequential(*trunk)

        assert len(cfg.heads) > 0
        head_modules = {}
        for spec in cfg.heads:
            stack = []
            for _ in range(spec.n_hidden_layers):
                stack.extend(hidden_block(cfg.n_neurons))
            stack.append(linear(cfg.n_neurons, spec.out_channels))
            head_modules[spec.name] = nn.Sequential(*stack)
        self.heads = nn.ModuleDict(head_modules)

        if cfg.chunk_mode is not None:
            assert cfg.chunk_size > 0

    def make_linear(
        self,
        dim_in,
        dim_out,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        """Create one ``nn.Linear`` and apply the requested initialization.

        ``weight_init`` may be ``None`` or ``"kaiming_uniform"``; ``bias_init``
        may be ``None`` or ``"zero"`` and is only considered when ``bias`` is
        truthy.

        Raises:
            NotImplementedError: for any other initializer name.
        """
        linear = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(linear.weight, nonlinearity="relu")
        elif weight_init is not None:
            raise NotImplementedError

        if bias and bias_init is not None:
            if bias_init != "zero":
                raise NotImplementedError
            torch.nn.init.zeros_(linear.bias)

        return linear

    def make_activation(self, activation):
        """Return an in-place ReLU or SiLU module for the given name.

        Raises:
            NotImplementedError: for any other name.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Run the trunk once and evaluate the selected heads.

        Args:
            x: input whose last dimension holds the features; all leading
                dimensions are flattened for the trunk and restored after it.
            include: if given, only heads with these names are evaluated.
            exclude: if given, heads with these names are skipped.

        Returns:
            Dict mapping each selected head name to its activated output.

        Raises:
            ValueError: if both ``include`` and ``exclude`` are given.
            NotImplementedError: for an unknown ``cfg.chunk_mode``.
        """
        leading_shape = x.shape[:-1]
        flat = x.reshape(-1, x.shape[-1])

        mode = self.cfg.chunk_mode
        if mode is None:
            trunk_out = self.shared_layers(flat)
        elif mode == "deferred":
            trunk_out = DeferredFunc.apply(
                self.shared_layers, flat, self.cfg.chunk_size
            )
        elif mode == "checkpointing":
            trunk_out = apply_batch_checkpointing(
                self.shared_layers, flat, self.cfg.chunk_size
            )
        else:
            raise NotImplementedError

        trunk_out = trunk_out.reshape(*leading_shape, -1)

        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            selected = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            selected = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            selected = self.cfg.heads

        # Heads are always evaluated directly; chunking applies to the trunk only.
        out = {}
        for spec in selected:
            head_out = self.heads[spec.name](trunk_out)
            out[spec.name] = get_activation(spec.output_activation)(head_out)
        return out
306
+
307
+
308
class DeferredFunc(torch.autograd.Function):
    """Chunked evaluation of a module with recomputation in the backward pass.

    The forward pass runs ``model`` on ``x`` in chunks of ``chunk_size`` rows
    without building a graph. The backward pass re-runs each chunk with
    gradients enabled on a copy of the model, then accumulates the resulting
    parameter gradients into the ``.grad`` of the original model's parameters.

    Call as ``DeferredFunc.apply(model, x, chunk_size)``.
    """

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(ctx, model, x, chunk_size):
        """Evaluate ``model`` chunk by chunk without recording a graph."""
        model_copy = deepcopy(model)
        model_copy.requires_grad_(False)

        with torch.no_grad():
            outputs = [model_copy(cur_x) for cur_x in torch.split(x, chunk_size, dim=0)]

        # The original model is kept so backward can write its gradients.
        ctx.model = model
        ctx.save_for_backward(x.detach(), torch.as_tensor(chunk_size))

        return torch.cat(outputs, dim=0)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        """Recompute each chunk with autograd and propagate gradients.

        Returns:
            ``(None, grad_input, None)``: gradients for ``model``, ``x`` and
            ``chunk_size`` respectively. Parameter gradients are accumulated
            into the original model as a side effect.
        """
        model = ctx.model
        x, chunk_size = ctx.saved_tensors
        chunk_size = chunk_size.item()

        # The copy keeps each parameter's requires_grad flag, so frozen
        # parameters stay frozen during recomputation.
        model_copy = deepcopy(model)

        x_split = torch.split(x, chunk_size, dim=0)
        grad_output_split = torch.split(grad_output, chunk_size, 0)
        grad_input_split = []

        with torch.set_grad_enabled(True):
            # deepcopy also copies existing .grad values; clear them so only
            # this backward pass contributes.
            model_copy.zero_grad()
            for cur_x, cur_grad_output in zip(x_split, grad_output_split):
                cur_x.requires_grad_(True)
                cur_y = model_copy(cur_x)
                cur_y.backward(cur_grad_output)

                grad_input_split.append(cur_x.grad.clone())

        grad_input = torch.cat(grad_input_split, dim=0)

        for param, param_copy in zip(model.parameters(), model_copy.parameters()):
            # Skip frozen parameters and parameters that received no gradient
            # (e.g. unused in forward); the latter have ``grad is None``.
            if not param.requires_grad or param_copy.grad is None:
                continue
            if param.grad is None:
                param.grad = param_copy.grad.clone()
            else:
                param.grad.add_(param_copy.grad)

        return None, grad_input, None
364
+
365
+
366
def apply_batch_checkpointing(func, x, chunk_size):
    """Apply ``func`` to ``x`` in row chunks under activation checkpointing.

    Args:
        func: callable applied to each chunk; its outputs are concatenated
            along dimension 0.
        x: input tensor, split along dimension 0.
        chunk_size: maximum number of rows per chunk.

    Returns:
        The concatenation of ``func(chunk)`` over all chunks, which equals
        ``func(x)`` for a ``func`` that acts on rows independently.
    """
    # Imported explicitly so the submodule is guaranteed to be loaded.
    from torch.utils.checkpoint import checkpoint

    if chunk_size >= len(x):
        return checkpoint(func, x, use_reentrant=False)

    # Checkpoint every chunk on its own and concatenate once at the end.
    # Concatenating inside each checkpoint made every checkpoint save the
    # ever-growing running output, i.e. quadratic memory and copies.
    outputs = [
        checkpoint(func, cur_x, use_reentrant=False)
        for cur_x in torch.split(x, chunk_size, dim=0)
    ]
    return torch.cat(outputs)
383
+
384
+
385
def get_encoding(n_input_dims: int, config) -> nn.Module:
    """Placeholder: encodings are not implemented in this module.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
387
+
388
+
389
def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module:
    """Placeholder: MLP construction from a config is not implemented here.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
3D_Stage/lrm/models/renderers/__init__.py ADDED
File without changes
3D_Stage/lrm/models/renderers/base.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ import lrm
7
+ from ..networks import MultiHeadMLP
8
+ from ..background.base import BaseBackground
9
+ from ..materials.base import BaseMaterial
10
+ from ...utils.base import BaseModule
11
+ from ...utils.typing import *
12
+
13
+
14
+ class BaseRenderer(BaseModule):
15
+ @dataclass
16
+ class Config(BaseModule.Config):
17
+ radius: float = 1.0
18
+
19
+ cfg: Config
20
+
21
+ def configure(
22
+ self,
23
+ decoder: MultiHeadMLP,
24
+ material: BaseMaterial,
25
+ background: BaseBackground,
26
+ ) -> None:
27
+ super().configure()
28
+
29
+ self.set_decoder(decoder)
30
+ self.set_material(material)
31
+ self.set_background(background)
32
+
33
+ # set up bounding box
34
+ self.bbox: Float[Tensor, "2 3"]
35
+ self.register_buffer(
36
+ "bbox",
37
+ torch.as_tensor(
38
+ [
39
+ [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
40
+ [self.cfg.radius, self.cfg.radius, self.cfg.radius],
41
+ ],
42
+ dtype=torch.float32,
43
+ ),
44
+ )
45
+
46
+ def forward(self, *args, **kwargs) -> Dict[str, Any]:
47
+ raise NotImplementedError
48
+
49
+ @property
50
+ def decoder(self) -> MultiHeadMLP:
51
+ return self.non_module("decoder")
52
+
53
+ @property
54
+ def material(self) -> BaseMaterial:
55
+ return self.non_module("material")
56
+
57
+ @property
58
+ def background(self) -> BaseBackground:
59
+ return self.non_module("background")
60
+
61
+ def set_decoder(self, decoder: MultiHeadMLP) -> None:
62
+ self.register_non_module("decoder", decoder)
63
+
64
+ def set_material(self, material: BaseMaterial) -> None:
65
+ self.register_non_module("material", material)
66
+
67
+ def set_background(self, background: BaseBackground) -> None:
68
+ self.register_non_module("background", background)